Grows the obs contract 32->96 as a SCHEMA change, not a dual hand-rewrite. obs_schema.json v2 adds the channels the trained AI needs: economy aggregates (yields/territory/golden-age), tech/culture/civics, military (army-health, experience, posture, equipment), per-city blocks, terrain summary + biome histogram, richer diplomacy, and the 6-wide CLAN ONE-HOT (the clan-conditioning input; -1 generalist = all-zero). Both interpreters gain the same new ops — onehot, frac (nested operands), histogram, per_entity, plus reduce/sum_len, count_nonnull, truthy, where_any (OR), bool->1.0 coercion in scalar reads. Multi-slot ops emit a run of consecutive slots. OBS_DIM 32->96, OBS_SCHEMA_VERSION 1->2. The contract gate earned its keep: it caught a real Python<->Rust divergence (a stale isinstance(dict) guard zeroed counts over string lists like city.buildings) then confirmed the fix byte-exact. Verified on the DO fleet: learned_encoder_parity green (Rust v2 == Python v2, 56 fixtures incl clan one-hot variety -1..5, zero drift); mc-player-api 188/188. Next: learned:clan-v1 controller wiring, then training on the richer obs. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
196 lines
7.5 KiB
Python
196 lines
7.5 KiB
Python
"""Schema-driven observation encoder — the Python half of the shared contract.

The observation layout is NOT hardcoded here; it is read from the single source
of truth `public/games/age-of-dwarves/data/ai/obs_schema.json`. The Rust half
(`mc-player-api/src/learned/encoder.rs`) interprets the same schema with a
byte-identical op vocabulary, asserted by `scripts/verify-obs-contract.sh`.

See `.project/designs/obs-contract.md`. Paths are PlayerView JSON WIRE keys
(note serde rename: `UnitView.type_id` -> `"type"`).
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
from functools import lru_cache
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
|
|
# Repository root: two directory levels above the directory containing this module.
_REPO_ROOT = os.path.abspath(
    os.path.join(
        os.path.dirname(__file__),
        "..",
        "..",
    )
)

# Default schema location, relative to the repository root.
_SCHEMA_REL = "public/games/age-of-dwarves/data/ai/obs_schema.json"
|
|
|
|
|
|
@lru_cache(maxsize=4)
def load_schema(path: str | None = None) -> dict[str, Any]:
    """Read the obs schema JSON file and sanity-check two of its keys.

    The file is picked in priority order: the explicit ``path`` argument, then
    the ``MC_OBS_SCHEMA`` environment variable, then the in-repo default
    location. Results are memoised by ``lru_cache`` on the ``path`` argument,
    so repeated calls with the same argument return the same dict object.

    Raises:
        ValueError: if ``obs_dim`` is not a positive integer, or if
            ``normalize`` is present and is neither ``"asinh"`` nor null.
    """
    schema_file = path
    if not schema_file:
        schema_file = os.environ.get("MC_OBS_SCHEMA")
    if not schema_file:
        schema_file = os.path.join(_REPO_ROOT, _SCHEMA_REL)

    with open(schema_file, encoding="utf-8") as handle:
        loaded = json.load(handle)

    declared_dim = loaded["obs_dim"]
    if int(declared_dim) <= 0:
        raise ValueError(f"obs schema: bad obs_dim {declared_dim!r}")

    mode = loaded.get("normalize")
    if mode is not None and mode != "asinh":
        raise ValueError(f"obs schema: unsupported normalize {mode!r}")

    return loaded
|
|
|
|
|
|
def obs_dim(path: str | None = None) -> int:
    """Return the observation length (``obs_dim``) declared by the schema.

    ``path`` is forwarded unchanged to ``load_schema``.
    """
    schema = load_schema(path)
    return int(schema["obs_dim"])
|
|
|
|
|
|
# ── primitives ───────────────────────────────────────────────────────
|
|
def _to_f(v: Any) -> float:
    """Turn a decoded JSON value into a float.

    Intended to byte-match the Rust `as_f64`: booleans become 1.0/0.0, numbers
    are converted as-is, and every other value (None, str, list, dict) is 0.0.
    """
    # bool is a subclass of int, so a single isinstance check covers
    # True -> 1.0 and False -> 0.0 as well as ordinary numbers.
    return float(v) if isinstance(v, (int, float)) else 0.0
|
|
|
|
|
|
def _truthy(v: Any) -> bool:
    """Decide whether a decoded JSON value counts as "true".

    Intended to match the Rust `value_truthy`: null is False, a bool is itself,
    a number is True unless it equals zero, a string is True unless empty, and
    anything else (arrays/objects) is always True.
    """
    if v is None:
        return False
    # bool must be handled before the numeric check: bool is an int subclass.
    if isinstance(v, bool):
        return v
    if isinstance(v, (int, float)):
        return v != 0
    if isinstance(v, str):
        return bool(v)
    return True
|
|
|
|
|
|
def _resolve(obj: Any, path: str) -> Any:
    """Follow a dot-separated key path through nested dicts.

    Returns the value found at the end of the path, or None as soon as a step
    lands on something that is not a dict or on a missing key.
    """
    node = obj
    for key in path.split("."):
        if not isinstance(node, dict):
            return None
        node = node.get(key)
    return node
|
|
|
|
|
|
def _list_at(view: Any, path: str) -> list:
    """Return the list found at dot-path ``path``, or [] if absent or not a list."""
    found = _resolve(view, path)
    if isinstance(found, list):
        return found
    return []
|
|
|
|
|
|
def _pred(item: dict, p: dict, me: Any) -> bool:
    """Evaluate one predicate ``p`` against ``item``.

    The value at ``p["field"]`` is looked up in ``item``. An ``eq`` predicate
    compares it for equality against the target, where the literal ``"$me"``
    stands for ``me``. Otherwise a set ``truthy`` flag tests the value's
    truthiness. A predicate with neither key always passes.
    """
    value = _resolve(item, p["field"])
    if "eq" in p:
        expected = p["eq"]
        if expected == "$me":
            expected = me
        return value == expected
    return _truthy(value) if p.get("truthy") else True
|
|
|
|
|
|
def _matches(item: dict, where: Any, where_any: Any, me: Any) -> bool:
    """Check ``item`` against the ``where`` (AND) and ``where_any`` (OR) filters.

    ``where`` may be None (no constraint), a single predicate, or a list of
    predicates that must all hold. If ``where_any`` is non-empty, at least one
    of its predicates must also hold.
    """
    if isinstance(where, list):
        required = where
    elif where is None:
        required = []
    else:
        required = [where]

    # Implicit AND over the required predicates.
    for clause in required:
        if not _pred(item, clause, me):
            return False

    # The OR group is only consulted when present and non-empty.
    if not where_any:
        return True
    for clause in where_any:
        if _pred(item, clause, me):
            return True
    return False
|
|
|
|
|
|
def _eval(view: dict, operand: Any, me: Any) -> float:
    """Evaluate a scalar operand to a float.

    A string operand is a wire path: it is resolved in ``view`` and coerced
    with ``_to_f``, staying at f64 precision (Rust reads it via `as_f64`).

    Any other operand is treated as a nested op. Its first slot is rounded
    through float32 to byte-match the Rust side, where `apply_field` returns
    `Vec<f32>` and the operand is therefore already f32-rounded before any
    later division.
    """
    if not isinstance(operand, str):
        first_slot = _apply_field(view, operand, me)[0]
        return float(np.float32(first_slot))
    return _to_f(_resolve(view, operand))
|
|
|
|
|
|
# ── op interpreter — returns the list of consecutive slot values ───────
|
|
def _apply_field(view: dict, fld: dict, me: Any) -> list[float]:
    """Evaluate one schema field against ``view``; return its slot values.

    The returned list holds the values for consecutive observation slots:
    one element for scalar-width ops, several for ``onehot``, ``histogram``
    and ``per_entity``.

    Supported ``fld["op"]`` values:

    * ``scalar`` - value at ``path`` coerced with ``_to_f``.
    * ``clamp_div`` - value at ``path`` divided by ``divisor``, capped at ``max``.
    * ``truthy`` - 1.0 if the value at ``path`` is truthy, else 0.0.
    * ``count_nonnull`` - how many of the paths in ``fields`` resolve to non-None.
    * ``frac`` - ``num / den`` (each a path or nested op); 0.0 when the
      denominator is zero or the result is not finite.
    * ``onehot`` - ``size`` slots; a numeric value ``v`` with ``0 <= v < size``
      sets slot ``v``; anything else leaves all slots zero.
    * ``histogram`` - ``len(vocab) + 1`` slots counting the ``field`` value of
      each item in ``list``; values not in ``vocab`` go to the last slot.
    * ``per_entity`` - ``k * stride`` slots; the first ``k`` items of ``list``
      each get their ``subfields`` evaluated into their own stride.
    * ``reduce`` - filter ``list`` with ``where`` / ``where_any``, then apply
      either ``contains`` (substring count) or ``agg`` (``count``,
      ``sum_len``, ``sum``, ``avg``) over ``select``.

    Args:
        view: the dict paths are resolved against (an entity for nested
            ``per_entity`` subfields).
        fld: the schema field description.
        me: value substituted for ``"$me"`` in ``eq`` predicates.

    Raises:
        ValueError: on an unknown ``op`` or an unknown ``agg``.
    """
    op = fld["op"]
    if op == "scalar":
        return [_to_f(_resolve(view, fld["path"]))]
    if op == "clamp_div":
        v = _to_f(_resolve(view, fld["path"]))
        return [min(float(fld["max"]), v / float(fld["divisor"]))]
    if op == "truthy":
        return [1.0 if _truthy(_resolve(view, fld["path"])) else 0.0]
    if op == "count_nonnull":
        return [float(sum(1 for p in fld["fields"] if _resolve(view, p) is not None))]
    if op == "frac":
        n, d = _eval(view, fld["num"], me), _eval(view, fld["den"], me)
        r = (n / d) if d != 0.0 else 0.0
        return [r if np.isfinite(r) else 0.0]
    if op == "onehot":
        size = int(fld["size"])
        out = [0.0] * size
        raw = _resolve(view, fld["path"])
        # Booleans are excluded on purpose: only real numbers select a slot.
        v = int(raw) if isinstance(raw, (int, float)) and not isinstance(raw, bool) else -1
        if 0 <= v < size:
            out[v] = 1.0
        return out
    if op == "histogram":
        vocab = fld["vocab"]
        other = len(vocab)  # index of the trailing "not in vocab" bucket
        out = [0.0] * (other + 1)
        idx = {b: k for k, b in enumerate(vocab)}
        for it in _list_at(view, fld["list"]):
            s = it.get(fld["field"]) if isinstance(it, dict) else None
            if isinstance(s, (list, dict)):
                # Lists/dicts are unhashable, so a dict lookup would raise
                # TypeError; they cannot name a vocab entry, so count them
                # in the "other" bucket like any unknown value.
                out[other] += 1.0
            else:
                out[idx.get(s, other)] += 1.0
        return out
    if op == "per_entity":
        k, stride = int(fld["k"]), int(fld["stride"])
        items = _list_at(view, fld["list"])
        out = [0.0] * (k * stride)
        # Entities beyond the first k are ignored; missing ones stay zero.
        for e in range(min(k, len(items))):
            ent = items[e]
            for s, sub in enumerate(fld["subfields"]):
                out[e * stride + s] = _apply_field(ent, sub, me)[0]
        return out
    if op == "reduce":
        items = _list_at(view, fld["list"])
        where, where_any = fld.get("where"), fld.get("where_any")
        # No isinstance(dict) guard: reduce may count a list of scalars (e.g.
        # city.buildings = ["forge", ...]); _resolve returns None for non-dict
        # items so field predicates simply don't match. Matches the Rust side.
        filt = [it for it in items if _matches(it, where, where_any, me)]
        contains = fld.get("contains")
        if contains is not None:
            needle, field = contains["needle"], contains["field"]
            return [float(sum(1 for it in filt if needle in str(_resolve(it, field) or "")))]
        agg = fld["agg"]
        if agg == "count":
            return [float(len(filt))]
        sel = fld.get("select")
        if agg == "sum_len":
            total = 0
            if sel:
                for it in filt:
                    seq = _resolve(it, sel)
                    # len() of a number/bool raises TypeError; values without
                    # a length contribute 0, the same as a missing value.
                    if isinstance(seq, (list, str, dict)):
                        total += len(seq)
            return [float(total)]
        vals = [_to_f(_resolve(it, sel)) if sel else 1.0 for it in filt]
        if agg == "sum":
            return [float(sum(vals))]
        if agg == "avg":
            return [float(sum(vals) / len(vals)) if vals else 0.0]
        raise ValueError(f"obs schema: unknown agg {agg!r}")
    raise ValueError(f"obs schema: unknown op {op!r}")
|
|
|
|
|
|
def encode_observation(view: dict[str, Any], path: str | None = None) -> np.ndarray:
    """Encode a PlayerView dict as the schema's fixed-length float32 vector.

    Each schema field writes its slot values starting at index ``fld["i"]``;
    slots no field writes stay 0. When the schema sets ``normalize`` to
    ``"asinh"``, arcsinh is applied in float64 and the result is cast back
    to float32.

    Args:
        view: decoded PlayerView JSON; ``view["player"]`` (default 0) is the
            value substituted for ``"$me"`` in predicates.
        path: optional schema path, forwarded to ``load_schema``.
    """
    schema = load_schema(path)
    length = int(schema["obs_dim"])
    obs = np.zeros(length, dtype=np.float32)
    me = int(view.get("player", 0))

    for fld in schema["fields"]:
        values = _apply_field(view, fld, me)
        # Multi-slot ops fill a run of consecutive indices from fld["i"].
        for slot, val in enumerate(values, start=int(fld["i"])):
            obs[slot] = val

    if schema.get("normalize") != "asinh":
        return obs
    widened = obs.astype(np.float64)
    return np.arcsinh(widened).astype(np.float32)
|