242 lines
10 KiB
Python
242 lines
10 KiB
Python
"""PlayerView ⇄ fixed-shape tensors for RL.
|
||
|
||
The wire-side `PlayerView` is a deeply-nested JSON dict; RL libraries
need fixed-shape numeric arrays. We pin two contracts here:

1. **Observation encoder** (`encode_observation`) projects the view into
   a fixed-length float32 vector. Length is `OBS_DIM`; layout is
   deterministic and documented inline so the policy net can learn a
   stable embedding.

2. **Action index encoder** (`encode_legal_actions` /
   `decode_action_index`) flattens the view's `legal_actions` (top-level
   + per-unit + per-city) into a fixed-size index space `[0, ACTION_DIM)`.
   Indices not occupied by a legal action in the current state are
   masked out by the boolean mask that `encode_legal_actions` returns.
   MaskablePPO consumes that mask directly.

These encoders are intentionally lossy — they discard tile-by-tile data
and only summarise the macro state. Replace with a CNN-based observation
once the macro head proves the loop works end-to-end.
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
from typing import Any
|
||
|
||
import numpy as np
|
||
|
||
# ── Observation shape ────────────────────────────────────────────────
# The fixed-length observation vector is four blocks of eight floats.
# Slots that `encode_observation` does not write stay at 0.0, so the
# vector length is stable while the blocks are filled in over time.
#   [0:8]    self resources + score: gold, gold_per_turn,
#            science_per_turn, score_estimate, city_count, unit_count,
#            happiness_pool, culture_per_turn
#   [8:16]   city / unit summary:
#              [8]  food summed across cities
#              [9]  production summed across cities
#              [10] average city population
#              [11] own units whose type contains "warrior"
#              [12] own units whose type contains "founder"
#              [13:16] unused (zero)
#   [16:24]  diplomacy snapshot:
#              [16] number of diplomacy entries
#              [17] # at war
#              [18] # at peace
#              [19] # with open_borders
#              [20:24] unused (zero)
#   [24:32]  turn counters:
#              [24] turn number
#              [25] fraction of game elapsed (turn / 1000, capped at 1.0)
#              [26:32] unused (zero)
OBS_DIM = 32
|
||
|
||
# ── Action index layout ──────────────────────────────────────────────
# We bucket legal actions deterministically:
#   [0]                    end_turn
#   [1]                    noop
#   [2 .. 2+MAX_UNITS*K)   per-unit slots, K = PER_UNIT_ACTIONS per unit:
#                            0      skip
#                            1      fortify
#                            2      sentry
#                            3      found_city
#                            4      unfortify
#                            5..10  move   N/NE/SE/S/SW/NW (6 dirs)
#                            11..16 attack N/NE/SE/S/SW/NW (6 dirs)
#   tail                   per-city build queue: MAX_CITIES blocks of
#                          CITY_QUEUE_DIM slots, each slot being the
#                          position of the item in CITY_QUEUE_ITEMS
#
# Anything legal but outside this layout is silently dropped — the RL
# agent simply can't learn to take it. For a duel game, the layout
# below covers >95% of legitimate openings; for the full 5-player
# huge-map case we extend MAX_UNITS / MAX_CITIES / CITY_QUEUE_ITEMS once
# the basic loop trains.
MAX_UNITS = 16
# 5 plain orders (skip, fortify, sentry, found_city, unfortify)
# + 6 move directions + 6 attack directions = 17 sub-slots (0..16).
# This was 16, which left no slot for the sixth attack direction: it was
# clamped onto sub-slot 15 and shared an index with the fifth one.
# NOTE: this widens ACTION_DIM from 322 to 338, so policies trained
# against the old layout must be retrained.
PER_UNIT_ACTIONS = 17
MAX_CITIES = 4
CITY_QUEUE_ITEMS: tuple[str, ...] = (
    "worker", "warrior", "library", "barracks", "forge",
    "walls", "longhouse", "monument", "dwarf_warrior", "dwarf_founder",
    "spearmen", "archer", "temple", "high_guild_hall", "chronicle_tower",
    "mead_hall",
)
CITY_QUEUE_DIM = len(CITY_QUEUE_ITEMS)

ACTION_DIM = (
    2  # end_turn, noop
    + MAX_UNITS * PER_UNIT_ACTIONS
    + MAX_CITIES * CITY_QUEUE_DIM
)
|
||
|
||
# Hex neighbour offsets. The world uses an **odd-q** offset layout
|
||
# (flat-top hexes laid out in columns; odd columns shifted down), so a
|
||
# hex's six neighbours are determined by its COLUMN parity. Verified
|
||
# against live `view.units[].legal_actions` move targets (Stage 6.1.6
|
||
# diagnostic). An earlier revision keyed the table on row parity and
|
||
# silently dropped roughly half of every odd-parity unit's legal moves
|
||
# out of the action mask — a latent bug in the RL env's move space.
|
||
# See public/games/age-of-dwarves/docs/HEX_GEOMETRY.md.
|
||
_DIR_OFFSETS_EVEN_COL: tuple[tuple[int, int], ...] = (
|
||
(0, -1), (1, -1), (1, 0), (0, 1), (-1, 0), (-1, -1),
|
||
)
|
||
_DIR_OFFSETS_ODD_COL: tuple[tuple[int, int], ...] = (
|
||
(0, -1), (1, 0), (1, 1), (0, 1), (-1, 1), (-1, 0),
|
||
)
|
||
|
||
|
||
def _hex_direction(from_pos: tuple[int, int], to_pos: tuple[int, int]) -> int | None:
|
||
"""Return 0..5 indexing the matching neighbour offset, or None if the
|
||
target is not one of `from_pos`'s six neighbours. odd-q layout — the
|
||
offset table is selected by the *column* parity of `from_pos`."""
|
||
fc, fr = from_pos
|
||
tc, tr = to_pos
|
||
dc, dr = tc - fc, tr - fr
|
||
table = _DIR_OFFSETS_EVEN_COL if (fc % 2 == 0) else _DIR_OFFSETS_ODD_COL
|
||
for i, (odc, odr) in enumerate(table):
|
||
if (odc, odr) == (dc, dr):
|
||
return i
|
||
return None
|
||
|
||
|
||
def encode_observation(view: dict[str, Any]) -> np.ndarray:
    """Flatten a PlayerView dict into a float32 vector of length `OBS_DIM`.

    Slots written (everything else stays 0.0):
      [0:8]   gold, gold_per_turn, science_per_turn, score_estimate,
              city_count, unit_count, happiness_pool, culture_per_turn
      [8:13]  summed city food, summed city production, mean city
              population, own "warrior"-typed units, own "founder"-typed
              units
      [16:20] diplomacy entries, # at war, # at peace, # with open borders
      [24:26] turn number, turn / 1000 capped at 1.0

    Missing keys fall back to zero / empty containers.
    """
    vec = np.zeros(OBS_DIM, dtype=np.float32)

    # ── [0:8] headline resources and score ──
    resources = view.get("resources", {})
    score = view.get("score", {})
    headline = (
        (resources, "gold"),
        (resources, "gold_per_turn"),
        (resources, "science_per_turn"),
        (score, "score_estimate"),
        (score, "city_count"),
        (score, "unit_count"),
        (resources, "happiness_pool"),
        (resources, "culture_per_turn"),
    )
    for slot, (source, key) in enumerate(headline):
        vec[slot] = float(source.get(key, 0.0))

    # ── [8:11] city aggregates (left at zero when there are no cities) ──
    city_list = view.get("cities", [])
    if city_list:

        def city_yield_total(name: str) -> float:
            return sum(
                float(city.get("yields", {}).get(name, 0)) for city in city_list
            )

        food_total = city_yield_total("food")
        prod_total = city_yield_total("production")
        vec[8] = food_total
        vec[9] = prod_total
        population_total = sum(
            float(city.get("population", 0)) for city in city_list
        )
        vec[10] = population_total / len(city_list)

    # ── [11:13] own unit counts, matched by substring of the unit type ──
    all_units = view.get("units", [])
    player_id = int(view.get("player", 0))
    own_types = [
        str(unit.get("type", ""))
        for unit in all_units
        if int(unit.get("owner", -1)) == player_id
    ]
    vec[11] = float(sum("warrior" in type_name for type_name in own_types))
    vec[12] = float(sum("founder" in type_name for type_name in own_types))

    # ── [16:20] diplomacy snapshot ──
    relations = view.get("diplomacy", [])
    vec[16] = float(len(relations))
    for slot, wanted in ((17, "war"), (18, "peace")):
        vec[slot] = float(
            sum(1 for rel in relations if rel.get("relation") == wanted)
        )
    vec[19] = float(sum(1 for rel in relations if rel.get("open_borders")))

    # ── [24:26] turn counters ──
    turn = float(view.get("turn", 0))
    vec[24] = turn
    # Bound turn at 1000 (Stage 6.1.5 max_turns) for a rough [0,1] progress signal.
    vec[25] = min(1.0, turn / 1000.0)
    return vec
|
||
|
||
|
||
def _unit_action_offset(unit_slot: int, sub: int) -> int:
    """Return the flat action index of sub-action `sub` for unit `unit_slot`.

    Per-unit blocks start right after the two global slots (end_turn,
    noop) and are `PER_UNIT_ACTIONS` wide each.
    """
    unit_block_start = 2 + unit_slot * PER_UNIT_ACTIONS
    return unit_block_start + sub
|
||
|
||
|
||
def _city_action_offset(city_slot: int, item_idx: int) -> int:
    """Return the flat action index for queueing item `item_idx` in city `city_slot`.

    City blocks sit after the two global slots and all per-unit blocks,
    and are `CITY_QUEUE_DIM` wide each.
    """
    city_region_start = 2 + MAX_UNITS * PER_UNIT_ACTIONS
    return city_region_start + city_slot * CITY_QUEUE_DIM + item_idx
|
||
|
||
|
||
# Per-unit sub-slot layout: five plain orders, then six move directions,
# then six attack directions. Direction order is whatever
# `_hex_direction` returns (0..5). The layout needs 17 sub-slots (0..16).
_UNIT_ORDER_SUBS: dict[str, int] = {
    "skip": 0,
    "fortify": 1,
    "sentry": 2,
    "found_city": 3,
    "unfortify": 4,
}
# Directed actions: action type -> (field holding the destination hex,
# first sub-slot of that action's six-direction run).
_UNIT_DIRECTED_SUBS: dict[str, tuple[str, int]] = {
    "move": ("to", 5),  # sub-slots 5..10
    "attack": ("target", 11),  # sub-slots 11..16
}


def _hex_coords(raw: Any) -> tuple[int, ...] | None:
    """Coerce a wire-side coordinate sequence to a tuple of ints; None if absent."""
    if raw is None:
        return None
    return tuple(int(x) for x in raw)


def _unit_sub_index(
    action: dict[str, Any], unit_pos: tuple[int, ...] | None
) -> int | None:
    """Map one unit action to its sub-slot, or None if the layout has no slot for it."""
    kind = action.get("type")
    if not isinstance(kind, str):
        return None
    plain = _UNIT_ORDER_SUBS.get(kind)
    if plain is not None:
        return plain
    directed = _UNIT_DIRECTED_SUBS.get(kind)
    if directed is None:
        return None
    field, base = directed
    target = _hex_coords(action.get(field))
    # Without both endpoints there is no direction to encode. Falling back
    # to (0, 0) here would fabricate a neighbour for units next to the origin.
    if unit_pos is None or target is None:
        return None
    direction = _hex_direction(unit_pos, target)
    if direction is None:
        # Not an adjacent hex (e.g. a ranged target): no slot in this layout.
        return None
    return base + direction


def encode_legal_actions(
    view: dict[str, Any],
) -> tuple[np.ndarray, dict[int, dict[str, Any]]]:
    """Build the action mask and an index→PlayerAction lookup table.

    Args:
        view: PlayerView dict. Reads `legal_actions`, `player`, `units`
            (each with `owner`, `position`, `legal_actions`) and `cities`
            (each with `legal_actions`); missing keys are treated as empty.

    Returns:
        `(mask, idx_to_action)`. `mask` is a bool array of length
        `ACTION_DIM`; `idx_to_action` maps each True index to the action
        dict taken from the view. Only entries present in the dict are
        legal this step. MaskablePPO uses the mask to zero out the
        sampling distribution before drawing.

    Legal actions that have no slot in the layout are silently dropped:
    unknown action types, moves/attacks whose destination is missing or
    not an adjacent hex, units beyond `MAX_UNITS`, cities beyond
    `MAX_CITIES`, and build items not in `CITY_QUEUE_ITEMS`. If two
    actions map to the same index, the later one in the view wins.
    """
    mask = np.zeros(ACTION_DIM, dtype=bool)
    idx_to_action: dict[int, dict[str, Any]] = {}

    def register(offset: int, action: dict[str, Any]) -> None:
        # Bounds check kept as a guard in case the layout constants drift apart.
        if 0 <= offset < ACTION_DIM:
            mask[offset] = True
            idx_to_action[offset] = action

    # ── Global slots ──
    for entry in view.get("legal_actions", []):
        action = entry.get("action", {})
        kind = action.get("type")
        if kind == "end_turn":
            register(0, action)
        elif kind == "noop":
            register(1, action)

    # ── Per-unit slots: only units owned by the viewing player ──
    all_units = view.get("units", [])
    me = int(view.get("player", 0))
    my_units = [u for u in all_units if int(u.get("owner", -1)) == me]
    for slot, unit in enumerate(my_units[:MAX_UNITS]):
        unit_pos = _hex_coords(unit.get("position"))
        for entry in unit.get("legal_actions", []):
            action = entry.get("action", {})
            sub = _unit_sub_index(action, unit_pos)
            # A sub-index that does not fit in the per-unit block is dropped
            # rather than clamped. Clamping put the overflowing attack
            # direction on the previous direction's index, so one index
            # could mean two different actions; leaving it unclamped would
            # spill into the next unit's block. PER_UNIT_ACTIONS must be at
            # least 17 for every direction to be representable.
            if sub is None or sub >= PER_UNIT_ACTIONS:
                continue
            register(_unit_action_offset(slot, sub), action)

    # ── Per-city build-queue slots ──
    # NOTE(review): units are filtered by owner but cities are not —
    # presumably the view only lists the player's own cities; confirm.
    for slot, city in enumerate(view.get("cities", [])[:MAX_CITIES]):
        for entry in city.get("legal_actions", []):
            action = entry.get("action", {})
            if action.get("type") != "queue_production":
                continue
            try:
                item_idx = CITY_QUEUE_ITEMS.index(str(action.get("item", "")))
            except ValueError:
                # Item is not in the fixed roster.
                continue
            register(_city_action_offset(slot, item_idx), action)

    return mask, idx_to_action
|
||
|
||
|
||
def decode_action_index(
    index: int, idx_to_action: dict[int, dict[str, Any]]
) -> dict[str, Any]:
    """Invert `encode_legal_actions`: look up the action stored at `index`.

    Args:
        index: Action index chosen by the policy. Plain ints and NumPy
            integer scalars are looked up as-is. An unhashable value such
            as a 0-d NumPy array is converted with `int()` first, since
            it cannot be used as a dict key.
        idx_to_action: Lookup table returned by `encode_legal_actions`.

    Returns:
        The stored action dict, or a fresh `{"type": "end_turn"}` if the
        index is not in the table. A masked index shouldn't be picked
        with MaskablePPO, but defensive code is cheap.

    Raises:
        TypeError: If `index` is unhashable and `int()` cannot convert it
            (for example an array with more than one element).
    """
    fallback = {"type": "end_turn"}
    try:
        return idx_to_action.get(index, fallback)
    except TypeError:
        # Unhashable index: retry with its integer value.
        return idx_to_action.get(int(index), fallback)
|