From 14fbe501cab050572a836f41ee822190a8ff556f Mon Sep 17 00:00:00 2001 From: Natalie Date: Sun, 17 May 2026 05:16:18 -0700 Subject: [PATCH] =?UTF-8?q?feat(tooling):=20=E2=9C=A8=20add=20turn=20track?= =?UTF-8?q?ing=20and=20forced=20end=20turn=20logic?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Lilith Autocommit --- tooling/rl_self_play/magic_civ_env.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tooling/rl_self_play/magic_civ_env.py b/tooling/rl_self_play/magic_civ_env.py index ff8fcd31..2c51d577 100644 --- a/tooling/rl_self_play/magic_civ_env.py +++ b/tooling/rl_self_play/magic_civ_env.py @@ -73,6 +73,8 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]): self._idx_to_action: dict[int, dict[str, Any]] = {} self._cur_mask: np.ndarray = np.zeros(ACTION_DIM, dtype=bool) self._terminated: bool = False + self._cur_turn: int = 0 + self._micro_actions_this_turn: int = 0 # ── Gymnasium API ──────────────────────────────────────────────── @@ -97,6 +99,8 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]): ) self._client = HarnessClient(cfg) self._terminated = False + self._cur_turn = 0 + self._micro_actions_this_turn = 0 view = self._client.view() self._sync_state(view) return encode_observation(view), {"action_mask": self._cur_mask.copy()} @@ -115,6 +119,18 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]): idx = 0 player_action = decode_action_index(idx, self._idx_to_action) + # Hard ceiling: if the policy refuses to end its turn after + # MAX_MICRO_ACTIONS_PER_TURN, force end_turn. Without this an eval + # policy that has learned "ending the turn lowers my reward" + # produces an episode of unbounded length. + forced_end = False + if ( + self._micro_actions_this_turn >= MAX_MICRO_ACTIONS_PER_TURN + and player_action.get("type") != "end_turn" + ): + player_action = {"type": "end_turn"} + forced_end = True + reward = 0.0 try: if player_action.get("type") == "end_turn": @@ -134,6 +150,12 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]): ) view = self._client.view() + new_turn = int(view.get("turn", 0)) + if new_turn != self._cur_turn: + self._cur_turn = new_turn + self._micro_actions_this_turn = 0 + else: + self._micro_actions_this_turn += 1 prev_score = self._last_score new_score = float(view.get("score", {}).get("score_estimate", 0.0)) reward += SCORE_DELTA_SCALE * (new_score - prev_score) @@ -154,6 +176,8 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]): } if reason: info["reason"] = reason + if forced_end: + info["forced_end_turn"] = True return encode_observation(view), reward, terminated, truncated, info def close(self) -> None: