magicciv/tooling/rl-self-play/encoders.py
Natalie ad108810dd feat(@projects/@magic-civilization): add rl-self-play harness and Claude player integration
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
2026-05-17 03:51:07 -07:00

236 lines
9.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""PlayerView ⇄ fixed-shape tensors for RL.
The wire-side `PlayerView` is a deeply-nested JSON dict; RL libraries
need fixed-shape numeric arrays. We pin two contracts here:
1. **Observation encoder** (`encode_observation`) projects the view into
a fixed-length float32 vector. Length is `OBS_DIM`; layout is
deterministic and documented inline so the policy net can learn a
stable embedding.
2. **Action index encoder** (`encode_legal_actions` /
`decode_action_index`) flattens the view's `legal_actions` (top-level
+ per-unit + per-city) into a fixed-size index space `[0, ACTION_DIM)`.
Indices not occupied by a legal action in the current state are
masked out by the boolean mask that `encode_legal_actions` returns
alongside its index-to-action table; MaskablePPO consumes that mask
directly.
These encoders are intentionally lossy — they discard tile-by-tile data
and only summarise the macro state. Replace with a CNN-based observation
once the macro head proves the loop works end-to-end.
"""
from __future__ import annotations
from typing import Any
import numpy as np
# ── Observation shape ────────────────────────────────────────────────
# The fixed-length observation vector is four blocks of 8 floats. Any
# index not listed below is currently always 0.0 (reserved padding).
#   [0:8]   self resources + score:
#             0 gold, 1 gold_per_turn, 2 science_per_turn,
#             3 score_estimate, 4 city_count, 5 unit_count,
#             6 happiness_pool, 7 culture_per_turn
#   [8:16]  self city / unit summary:
#             8  food yield summed across cities,
#             9  production yield summed across cities,
#             10 average city population,
#             11 own units whose type contains "warrior",
#             12 own units whose type contains "founder",
#             13..15 reserved
#   [16:24] diplomacy snapshot:
#             16 number of diplomacy entries, 17 # at war,
#             18 # at peace, 19 # with open borders, 20..23 reserved
#   [24:32] turn counters:
#             24 turn number, 25 min(1, turn / 500), 26..31 reserved
# TODO: science/gold/culture city yields and cities lost/captured were
# planned for the reserved slots but are not encoded yet.
OBS_DIM = 32
# ── Action index layout ──────────────────────────────────────────────
# Legal actions are bucketed deterministically into [0, ACTION_DIM):
#   [0]    end_turn
#   [1]    noop
#   [2, 2 + MAX_UNITS * PER_UNIT_ACTIONS)
#          per-unit slots: one block of PER_UNIT_ACTIONS per owned unit.
#          Sub-index within a unit's block:
#            0 skip, 1 fortify, 2 sentry, 3 found_city, 4 unfortify,
#            5..10  move to the N/NE/SE/S/SW/NW neighbour,
#            11..16 attack the N/NE/SE/S/SW/NW neighbour
#   tail   per-city build queue: MAX_CITIES blocks of CITY_QUEUE_DIM, one
#          index per entry of CITY_QUEUE_ITEMS, in tuple order.
#
# Anything legal but outside this layout is silently dropped — the RL
# agent simply can't learn to take it. For a duel game this layout is
# expected to cover >95% of legitimate openings; for the full 5-player
# huge-map case we extend MAX_UNITS / MAX_CITIES / CITY_QUEUE_ITEMS once
# the basic loop trains.
MAX_UNITS = 16
# 5 stationary actions + 6 move directions + 6 attack directions = 17.
# With only 16 slots the last attack direction has no slot of its own and
# collides with its neighbour. Changing this resizes ACTION_DIM, so
# policies trained with a different value are not interchangeable.
PER_UNIT_ACTIONS = 17
MAX_CITIES = 4
# Fixed, priority-ordered production roster; an item's position here is
# its offset within a city's block of the action space.
CITY_QUEUE_ITEMS: tuple[str, ...] = (
    "worker", "warrior", "library", "barracks", "forge",
    "walls", "longhouse", "monument", "dwarf_warrior", "dwarf_founder",
    "spearmen", "archer", "temple", "high_guild_hall", "chronicle_tower",
    "mead_hall",
)
CITY_QUEUE_DIM = len(CITY_QUEUE_ITEMS)
ACTION_DIM = (
    2  # end_turn, noop
    + MAX_UNITS * PER_UNIT_ACTIONS
    + MAX_CITIES * CITY_QUEUE_DIM
)
# Hex-axial direction order (matches `legal_actions` move targets after
# canonicalising relative-direction). Pointy-top, offset coords:
# Even-r layout used by mc-core. Order is N, NE, SE, S, SW, NW.
_DIR_OFFSETS_EVEN: tuple[tuple[int, int], ...] = (
(0, -1), (1, -1), (1, 0), (0, 1), (-1, 0), (-1, -1),
)
_DIR_OFFSETS_ODD: tuple[tuple[int, int], ...] = (
(0, -1), (1, 0), (1, 1), (0, 1), (-1, 1), (-1, 0),
)
def _hex_direction(from_pos: tuple[int, int], to_pos: tuple[int, int]) -> int | None:
"""Return 0..5 for the matching cardinal direction, or None if the
target is not adjacent. Even/odd-row aware (offset-r layout)."""
fc, fr = from_pos
tc, tr = to_pos
dc, dr = tc - fc, tr - fr
table = _DIR_OFFSETS_EVEN if (fr % 2 == 0) else _DIR_OFFSETS_ODD
for i, (odc, odr) in enumerate(table):
if (odc, odr) == (dc, dr):
return i
return None
def encode_observation(view: dict[str, Any]) -> np.ndarray:
    """Project a PlayerView dict into a fixed-shape float32 vector.

    Args:
        view: A decoded PlayerView JSON object. Missing keys, and keys whose
            value is JSON ``null``, are treated as empty / zero.

    Returns:
        A float32 array of length ``OBS_DIM``. Slots written (all others
        stay 0.0):

        - 0..7: gold, gold_per_turn, science_per_turn (``resources``),
          score_estimate, city_count, unit_count (``score``),
          happiness_pool, culture_per_turn (``resources``)
        - 8, 9: food / production yields summed over this player's cities
        - 10: mean population of this player's cities (0 when none)
        - 11, 12: this player's units whose ``type`` contains "warrior" /
          "founder"
        - 16..19: diplomacy entries, # at war, # at peace, # open borders
        - 24: turn number; 25: min(1, turn / 500)
    """
    obs = np.zeros(OBS_DIM, dtype=np.float32)
    me = int(view.get("player", 0))

    def num(section: dict[str, Any], key: str) -> float:
        # `or 0.0` also covers a field that is present but JSON null.
        return float(section.get(key) or 0.0)

    def is_own_city(city: dict[str, Any]) -> bool:
        # Units are owner-filtered below, so cities are too. Cities without
        # an "owner" field are kept.
        # NOTE(review): assumes foreign cities, if the view contains any,
        # carry an int "owner" like units do — confirm against mc-core.
        owner = city.get("owner")
        return owner is None or int(owner) == me

    res = view.get("resources") or {}
    score = view.get("score") or {}
    obs[0] = num(res, "gold")
    obs[1] = num(res, "gold_per_turn")
    obs[2] = num(res, "science_per_turn")
    obs[3] = num(score, "score_estimate")
    obs[4] = num(score, "city_count")
    obs[5] = num(score, "unit_count")
    obs[6] = num(res, "happiness_pool")
    obs[7] = num(res, "culture_per_turn")

    cities = [c for c in (view.get("cities") or []) if is_own_city(c)]
    if cities:
        yields = [c.get("yields") or {} for c in cities]
        obs[8] = sum(num(y, "food") for y in yields)
        obs[9] = sum(num(y, "production") for y in yields)
        obs[10] = sum(num(c, "population") for c in cities) / len(cities)

    my_unit_types = [
        str(u.get("type", ""))
        for u in (view.get("units") or [])
        if int(u.get("owner", -1)) == me
    ]
    # Substring heuristic: "warrior" also matches e.g. "dwarf_warrior", but
    # other military types (spearmen, archer, ...) are not counted.
    obs[11] = float(sum("warrior" in t for t in my_unit_types))
    obs[12] = float(sum("founder" in t for t in my_unit_types))

    diplo = view.get("diplomacy") or []
    obs[16] = float(len(diplo))
    obs[17] = float(sum(d.get("relation") == "war" for d in diplo))
    obs[18] = float(sum(d.get("relation") == "peace" for d in diplo))
    obs[19] = float(sum(bool(d.get("open_borders")) for d in diplo))

    turn = num(view, "turn")
    obs[24] = turn
    # Bound turn at 500 (huge-map limit) for a rough [0,1] progress signal.
    obs[25] = min(1.0, turn / 500.0)
    return obs
def _unit_action_offset(unit_slot: int, sub: int) -> int:
return 2 + unit_slot * PER_UNIT_ACTIONS + sub
def _city_action_offset(city_slot: int, item_idx: int) -> int:
return 2 + MAX_UNITS * PER_UNIT_ACTIONS + city_slot * CITY_QUEUE_DIM + item_idx
# Per-unit sub-indices for actions that have no direction.
_UNIT_FIXED_SUBS: dict[str, int] = {
    "skip": 0,
    "fortify": 1,
    "sentry": 2,
    "found_city": 3,
    "unfortify": 4,
}
# Directional unit actions: (key holding the target position, first
# sub-index). Each occupies 6 consecutive sub-indices, N..NW.
_UNIT_DIRECTIONAL_SUBS: dict[str, tuple[str, int]] = {
    "move": ("to", 5),  # 5..10
    "attack": ("target", 11),  # 11..16
}


def _unit_action_sub(upos: tuple[int, ...], action: dict[str, Any]) -> int | None:
    """Map one unit action to its sub-index within the unit's block, or None
    if it has no place in the layout (unknown type, non-adjacent target)."""
    kind = action.get("type")
    if not isinstance(kind, str):
        return None
    if kind in _UNIT_FIXED_SUBS:
        return _UNIT_FIXED_SUBS[kind]
    if kind not in _UNIT_DIRECTIONAL_SUBS:
        return None
    pos_key, base = _UNIT_DIRECTIONAL_SUBS[kind]
    target = tuple(int(x) for x in action.get(pos_key, (0, 0)))
    dir_idx = _hex_direction(upos, target)
    return None if dir_idx is None else base + dir_idx


def encode_legal_actions(
    view: dict[str, Any],
) -> tuple[np.ndarray, dict[int, dict[str, Any]]]:
    """Build the action mask and an index -> PlayerAction lookup table.

    Args:
        view: A decoded PlayerView JSON object. Its ``legal_actions`` lists
            (top-level, per unit, per city) hold entries shaped like
            ``{"action": {...}}``; missing or JSON-null lists are treated
            as empty.

    Returns:
        ``(mask, idx_to_action)``. ``mask`` is a bool array of length
        ``ACTION_DIM`` that is True exactly at the keys of
        ``idx_to_action``, which maps each legal flat index to the action
        dict (as given in the view) it stands for. MaskablePPO uses the
        mask to zero out the sampling distribution before drawing.

    Unit slots follow the order of this player's units in ``view["units"]``
    (first ``MAX_UNITS`` only), and city slots follow the order of this
    player's cities (first ``MAX_CITIES`` only). Actions with no slot in the
    layout are dropped: unknown types, non-adjacent move/attack targets,
    sub-indices beyond ``PER_UNIT_ACTIONS``, and items not in
    ``CITY_QUEUE_ITEMS``.
    """
    mask = np.zeros(ACTION_DIM, dtype=bool)
    idx_to_action: dict[int, dict[str, Any]] = {}
    me = int(view.get("player", 0))

    def claim(index: int, action: dict[str, Any]) -> None:
        # Mark `index` legal and remember which action it decodes to.
        if index < ACTION_DIM:
            mask[index] = True
            idx_to_action[index] = action

    def is_own_city(city: dict[str, Any]) -> bool:
        # Foreign cities must not consume build-queue slots. Cities without
        # an "owner" field are kept.
        # NOTE(review): assumes foreign cities, if the view contains any,
        # carry an int "owner" like units do — confirm against mc-core.
        owner = city.get("owner")
        return owner is None or int(owner) == me

    for entry in view.get("legal_actions") or []:
        action = entry.get("action") or {}
        kind = action.get("type")
        if kind == "end_turn":
            claim(0, action)
        elif kind == "noop":
            claim(1, action)

    my_units = [
        u for u in (view.get("units") or []) if int(u.get("owner", -1)) == me
    ]
    for slot, unit in enumerate(my_units[:MAX_UNITS]):
        upos = tuple(int(x) for x in unit.get("position", (0, 0)))
        for entry in unit.get("legal_actions") or []:
            action = entry.get("action") or {}
            sub = _unit_action_sub(upos, action)
            if sub is None or sub >= PER_UNIT_ACTIONS:
                # No slot in the layout: drop it. Clamping it into range
                # would alias it onto, and overwrite, a different action.
                continue
            claim(_unit_action_offset(slot, sub), action)

    my_cities = [c for c in (view.get("cities") or []) if is_own_city(c)]
    for slot, city in enumerate(my_cities[:MAX_CITIES]):
        for entry in city.get("legal_actions") or []:
            action = entry.get("action") or {}
            if action.get("type") != "queue_production":
                continue
            item = str(action.get("item", ""))
            if item not in CITY_QUEUE_ITEMS:
                continue
            claim(_city_action_offset(slot, CITY_QUEUE_ITEMS.index(item)), action)

    return mask, idx_to_action
def decode_action_index(
    index: int, idx_to_action: dict[int, dict[str, Any]]
) -> dict[str, Any]:
    """Map a policy-chosen flat index back to its PlayerAction dict.

    Inverse of ``encode_legal_actions``: returns the table entry for
    ``index``. An index missing from the table (a masked index, which
    MaskablePPO should never pick) falls back to a fresh ``end_turn``
    action.
    """
    if index in idx_to_action:
        return idx_to_action[index]
    return {"type": "end_turn"}