feat(tooling): add turn tracking and forced end turn logic

Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
Natalie 2026-05-17 05:16:18 -07:00
parent de5fbd42c4
commit 14fbe501ca

View file

@ -73,6 +73,8 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
self._idx_to_action: dict[int, dict[str, Any]] = {}
self._cur_mask: np.ndarray = np.zeros(ACTION_DIM, dtype=bool)
self._terminated: bool = False
self._cur_turn: int = 0
self._micro_actions_this_turn: int = 0
# ── Gymnasium API ────────────────────────────────────────────────
@ -97,6 +99,8 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
)
self._client = HarnessClient(cfg)
self._terminated = False
self._cur_turn = 0
self._micro_actions_this_turn = 0
view = self._client.view()
self._sync_state(view)
return encode_observation(view), {"action_mask": self._cur_mask.copy()}
@ -115,6 +119,18 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
idx = 0
player_action = decode_action_index(idx, self._idx_to_action)
# Hard ceiling: if the policy refuses to end its turn after
# MAX_MICRO_ACTIONS_PER_TURN, force end_turn. Without this an eval
# policy that has learned "ending the turn lowers my reward"
# produces an episode of unbounded length.
forced_end = False
if (
self._micro_actions_this_turn >= MAX_MICRO_ACTIONS_PER_TURN
and player_action.get("type") != "end_turn"
):
player_action = {"type": "end_turn"}
forced_end = True
reward = 0.0
try:
if player_action.get("type") == "end_turn":
@ -134,6 +150,12 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
)
view = self._client.view()
new_turn = int(view.get("turn", 0))
if new_turn != self._cur_turn:
self._cur_turn = new_turn
self._micro_actions_this_turn = 0
else:
self._micro_actions_this_turn += 1
prev_score = self._last_score
new_score = float(view.get("score", {}).get("score_estimate", 0.0))
reward += SCORE_DELTA_SCALE * (new_score - prev_score)
@ -154,6 +176,8 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
}
if reason:
info["reason"] = reason
if forced_end:
info["forced_end_turn"] = True
return encode_observation(view), reward, terminated, truncated, info
def close(self) -> None: