magicciv/tooling/rl_self_play/obs_contract.py
Natalie a6fb75a480
Some checks are pending
ci / regression gate (push) Waiting to run
deploy-next / deploy dev guide to mc.next.black.lan (push) Waiting to run
feat(ai): v2 richer 96-dim clan-conditioned observation (schema data + 4 new ops)
Grows the obs contract 32->96 as a SCHEMA change, not a dual hand-rewrite. obs_schema.json
v2 adds the channels the trained AI needs: economy aggregates (yields/territory/golden-age),
tech/culture/civics, military (army-health, experience, posture, equipment), per-city blocks,
terrain summary + biome histogram, richer diplomacy, and the 6-wide CLAN ONE-HOT (the
clan-conditioning input; -1 generalist = all-zero).

Both interpreters gain the same new ops — onehot, frac (nested operands), histogram,
per_entity, plus reduce/sum_len, count_nonnull, truthy, where_any (OR), bool->1.0 coercion
in scalar reads. Multi-slot ops emit a run of consecutive slots. OBS_DIM 32->96,
OBS_SCHEMA_VERSION 1->2.

The contract gate earned its keep: it caught a real Python<->Rust divergence (a stale
isinstance(dict) guard zeroed counts over string lists like city.buildings) then confirmed
the fix byte-exact. Verified on the DO fleet: learned_encoder_parity green (Rust v2 ==
Python v2 across 56 fixtures, including clan one-hot values -1..5; zero drift); mc-player-api 188/188.

Next: learned:clan-v1 controller wiring, then training on the richer obs.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-30 12:38:50 -04:00

196 lines
7.5 KiB
Python

"""Schema-driven observation encoder — the Python half of the shared contract.
The observation layout is NOT hardcoded here; it is read from the single source
of truth `public/games/age-of-dwarves/data/ai/obs_schema.json`. The Rust half
(`mc-player-api/src/learned/encoder.rs`) interprets the same schema with a
byte-identical op vocabulary, asserted by `scripts/verify-obs-contract.sh`.
See `.project/designs/obs-contract.md`. Paths are PlayerView JSON WIRE keys
(note serde rename: `UnitView.type_id` -> `"type"`).
"""
from __future__ import annotations
import json
import os
from functools import lru_cache
from typing import Any
import numpy as np
# Presumed repository root: two directory levels above this module's directory.
_REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Default schema location relative to _REPO_ROOT; load_schema() uses it only when
# neither an explicit path nor the MC_OBS_SCHEMA environment variable is given.
_SCHEMA_REL = "public/games/age-of-dwarves/data/ai/obs_schema.json"
@lru_cache(maxsize=4)
def load_schema(path: str | None = None) -> dict[str, Any]:
"""Load + lightly validate the obs schema. Cached by path."""
p = path or os.environ.get("MC_OBS_SCHEMA") or os.path.join(_REPO_ROOT, _SCHEMA_REL)
with open(p, encoding="utf-8") as fh:
schema = json.load(fh)
if int(schema["obs_dim"]) <= 0:
raise ValueError(f"obs schema: bad obs_dim {schema['obs_dim']!r}")
if schema.get("normalize") not in ("asinh", None):
raise ValueError(f"obs schema: unsupported normalize {schema.get('normalize')!r}")
return schema
def obs_dim(path: str | None = None) -> int:
    """Return the observation width (``obs_dim``) declared by the schema.

    *path* is forwarded unchanged to ``load_schema``, which resolves the
    file and caches the parsed result.
    """
    schema = load_schema(path)
    width = int(schema["obs_dim"])
    return width
# ── primitives ───────────────────────────────────────────────────────
def _to_f(v: Any) -> float:
"""Coerce a JSON value to float, byte-matching the Rust `as_f64`:
bool -> 1.0/0.0, number -> float, everything else (None/str/list) -> 0.0."""
if isinstance(v, bool):
return 1.0 if v else 0.0
if isinstance(v, (int, float)):
return float(v)
return 0.0
def _truthy(v: Any) -> bool:
"""Match Rust `value_truthy`: bool as-is, number != 0, non-empty string,
null -> False, other (arrays/objects) -> True."""
if isinstance(v, bool):
return v
if isinstance(v, (int, float)):
return v != 0
if isinstance(v, str):
return len(v) > 0
if v is None:
return False
return True
def _resolve(obj: Any, path: str) -> Any:
"""Walk a dot-path through nested dicts. Returns None on any miss."""
cur = obj
for part in path.split("."):
if isinstance(cur, dict):
cur = cur.get(part)
else:
return None
return cur
def _list_at(view: Any, path: str) -> list:
    """Fetch the list stored at dot-path *path* inside *view*.

    Returns ``[]`` when the path is missing or holds anything other than a
    list, so callers can always iterate the result.
    """
    found = _resolve(view, path)
    if isinstance(found, list):
        return found
    return []
def _pred(item: dict, p: dict, me: Any) -> bool:
    """Test a single predicate *p* against *item*.

    ``p["field"]`` is a dot-path into *item*. Supported predicate forms:

    * ``{"eq": value}`` passes on equality. The literal ``"$me"`` is replaced
      by *me*, the observing player's id.
    * ``{"truthy": true}`` passes when the field is truthy per ``_truthy``.

    ``eq`` wins when both keys are present. A predicate with neither key, or
    with a falsy ``truthy``, always passes. Non-dict items (e.g. bare strings
    from a scalar list) make every field resolve to ``None`` via ``_resolve``.
    """
    value = _resolve(item, p["field"])
    if "eq" not in p:
        return _truthy(value) if p.get("truthy") else True
    expected = p["eq"]
    if expected == "$me":
        expected = me
    return value == expected
def _matches(item: dict, where: Any, where_any: Any, me: Any) -> bool:
    """Decide whether *item* passes a ``reduce`` filter.

    ``where`` is an AND constraint: a single predicate dict, a list of
    predicate dicts, or ``None`` for no constraint. ``where_any`` is an OR
    group (a list of predicates, at least one of which must hold), combined
    with ``where`` by AND. An absent or empty group imposes nothing.
    ``where_any`` is only evaluated once ``where`` has passed.
    """
    if where is not None:
        required = where if isinstance(where, list) else [where]
        if not all(_pred(item, p, me) for p in required):
            return False
    if where_any:
        return any(_pred(item, p, me) for p in where_any)
    return True
def _eval(view: dict, operand: Any, me: Any) -> float:
    """Scalar value of an operand (as used by the ``frac`` op).

    * A string operand is a wire dot-path. It is read through ``_to_f`` at
      full double precision, since the Rust side reads it via ``as_f64``.
    * Any other operand is a nested op spec. Only its first slot is used, and
      that value is rounded through float32, because Rust's ``apply_field``
      yields ``Vec<f32>``. Rounding here keeps the subsequent division
      byte-identical across the two interpreters.
    """
    if not isinstance(operand, str):
        first_slot = _apply_field(view, operand, me)[0]
        return float(np.float32(first_slot))
    return _to_f(_resolve(view, operand))
# ── op interpreter — returns the list of consecutive slot values ───────
def _apply_field(view: dict, fld: dict, me: Any) -> list[float]:
    """Evaluate one schema field spec against *view*.

    Returns the run of consecutive slot values the field occupies:

    * one value for the scalar-width ops (scalar, clamp_div, truthy,
      count_nonnull, frac, reduce);
    * ``size`` values for ``onehot``;
    * ``len(vocab) + 1`` values for ``histogram``;
    * ``k * stride`` values for ``per_entity``.

    Every op is meant to stay byte-identical with the Rust interpreter
    (``mc-player-api/src/learned/encoder.rs``).

    Args:
        view: the JSON object being read. This is a PlayerView at the top
            level, or a single list item when called from ``per_entity``.
        fld: the op spec. ``fld["op"]`` selects the op.
        me: the observing player's id, substituted for ``"$me"`` in
            ``where`` / ``where_any`` predicates.

    Raises:
        ValueError: unknown ``op`` or ``agg``.
        KeyError: a key the op requires is missing from *fld*.
    """
    op = fld["op"]
    if op == "scalar":
        return [_to_f(_resolve(view, fld["path"]))]
    if op == "clamp_div":
        # Upper clamp only: quotients below `max` (including negatives) pass through.
        v = _to_f(_resolve(view, fld["path"]))
        return [min(float(fld["max"]), v / float(fld["divisor"]))]
    if op == "truthy":
        return [1.0 if _truthy(_resolve(view, fld["path"])) else 0.0]
    if op == "count_nonnull":
        return [float(sum(1 for p in fld["fields"] if _resolve(view, p) is not None))]
    if op == "frac":
        n, d = _eval(view, fld["num"], me), _eval(view, fld["den"], me)
        r = (n / d) if d != 0.0 else 0.0
        # Infinite or NaN quotients collapse to 0.0.
        return [r if np.isfinite(r) else 0.0]
    if op == "onehot":
        size = int(fld["size"])
        out = [0.0] * size
        raw = _resolve(view, fld["path"])
        # Only real numbers select a slot (bool is excluded even though it
        # subclasses int). Anything else, or an out-of-range index such as -1,
        # leaves the block all-zero.
        v = int(raw) if isinstance(raw, (int, float)) and not isinstance(raw, bool) else -1
        if 0 <= v < size:
            out[v] = 1.0
        return out
    if op == "histogram":
        vocab = fld["vocab"]
        other = len(vocab)  # trailing bucket for values not in vocab
        out = [0.0] * (other + 1)
        idx = {b: k for k, b in enumerate(vocab)}
        for it in _list_at(view, fld["list"]):
            s = it.get(fld["field"]) if isinstance(it, dict) else None
            # JSON arrays/objects are unhashable, so `idx.get` would raise
            # TypeError on them. They can never equal a vocab entry anyway,
            # so count them in the "other" bucket.
            slot = other if isinstance(s, (list, dict)) else idx.get(s, other)
            out[slot] += 1.0
        return out
    if op == "per_entity":
        # k blocks of `stride` slots, one per leading list item. Each subfield
        # contributes only its first slot; blocks for absent items stay zero.
        k, stride = int(fld["k"]), int(fld["stride"])
        items = _list_at(view, fld["list"])
        out = [0.0] * (k * stride)
        for e in range(min(k, len(items))):
            ent = items[e]
            for s, sub in enumerate(fld["subfields"]):
                out[e * stride + s] = _apply_field(ent, sub, me)[0]
        return out
    if op == "reduce":
        items = _list_at(view, fld["list"])
        where, where_any = fld.get("where"), fld.get("where_any")
        # No isinstance(dict) guard: reduce may count a list of scalars (e.g.
        # city.buildings = ["forge", ...]); _resolve returns None for non-dict
        # items so field predicates simply don't match. Matches the Rust side.
        filt = [it for it in items if _matches(it, where, where_any, me)]
        contains = fld.get("contains")
        if contains is not None:
            needle, field = contains["needle"], contains["field"]
            return [float(sum(1 for it in filt if needle in str(_resolve(it, field) or "")))]
        agg = fld["agg"]
        if agg == "count":
            return [float(len(filt))]
        sel = fld.get("select")
        if agg == "sum_len":
            if not sel:
                return [0.0]
            # Only list values contribute their length. A bare len() raised
            # TypeError on numbers/true and counted characters of strings.
            # Anything that is not a list counts as 0, the same list-or-nothing
            # rule `_list_at` applies.
            total = 0
            for it in filt:
                selected = _resolve(it, sel)
                if isinstance(selected, list):
                    total += len(selected)
            return [float(total)]
        vals = [_to_f(_resolve(it, sel)) if sel else 1.0 for it in filt]
        if agg == "sum":
            return [float(sum(vals))]
        if agg == "avg":
            return [float(sum(vals) / len(vals)) if vals else 0.0]
        raise ValueError(f"obs schema: unknown agg {agg!r}")
    raise ValueError(f"obs schema: unknown op {op!r}")
def encode_observation(view: dict[str, Any], path: str | None = None) -> np.ndarray:
    """Project a PlayerView dict into the schema's fixed-shape float32 obs.

    Each schema field writes its run of slot values starting at index
    ``fld["i"]``. Slots that no field writes stay 0.0. When the schema
    declares ``"normalize": "asinh"``, ``asinh`` is applied elementwise,
    computed in float64 and stored back as float32.

    Args:
        view: PlayerView JSON wire object. ``view["player"]`` (default 0) is
            the observing player, used for ``"$me"`` predicates.
        path: optional schema path, forwarded to ``load_schema``.

    Returns:
        A float32 vector of length ``obs_dim``.

    Raises:
        ValueError: a field's slot run does not fit inside ``[0, obs_dim)``.
            Without this check a negative ``i`` would silently wrap around to
            the tail of the vector (numpy negative indexing) and overwrite
            another field's slots.
    """
    schema = load_schema(path)
    dim = int(schema["obs_dim"])
    obs = np.zeros(dim, dtype=np.float32)
    me = int(view.get("player", 0))
    for fld in schema["fields"]:
        base = int(fld["i"])
        vals = _apply_field(view, fld, me)
        end = base + len(vals)
        if base < 0 or end > dim:
            raise ValueError(
                f"obs schema: field at i={base} (op {fld['op']!r}) spans slots "
                f"[{base}, {end}) outside obs_dim {dim}"
            )
        for j, val in enumerate(vals):
            obs[base + j] = val
    if schema.get("normalize") == "asinh":
        obs = np.arcsinh(obs.astype(np.float64)).astype(np.float32)
    return obs