feat(tooling): ✨ add turn tracking and forced end turn logic
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
de5fbd42c4
commit
14fbe501ca
1 changed file with 24 additions and 0 deletions
|
|
@ -73,6 +73,8 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
self._idx_to_action: dict[int, dict[str, Any]] = {}
|
||||
self._cur_mask: np.ndarray = np.zeros(ACTION_DIM, dtype=bool)
|
||||
self._terminated: bool = False
|
||||
self._cur_turn: int = 0
|
||||
self._micro_actions_this_turn: int = 0
|
||||
|
||||
# ── Gymnasium API ────────────────────────────────────────────────
|
||||
|
||||
|
|
@ -97,6 +99,8 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
)
|
||||
self._client = HarnessClient(cfg)
|
||||
self._terminated = False
|
||||
self._cur_turn = 0
|
||||
self._micro_actions_this_turn = 0
|
||||
view = self._client.view()
|
||||
self._sync_state(view)
|
||||
return encode_observation(view), {"action_mask": self._cur_mask.copy()}
|
||||
|
|
@ -115,6 +119,18 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
idx = 0
|
||||
player_action = decode_action_index(idx, self._idx_to_action)
|
||||
|
||||
# Hard ceiling: if the policy refuses to end its turn after
|
||||
# MAX_MICRO_ACTIONS_PER_TURN micro-actions, force end_turn. Without this, an eval
|
||||
# policy that has learned "ending the turn lowers my reward"
|
||||
# produces an episode of unbounded length.
|
||||
forced_end = False
|
||||
if (
|
||||
self._micro_actions_this_turn >= MAX_MICRO_ACTIONS_PER_TURN
|
||||
and player_action.get("type") != "end_turn"
|
||||
):
|
||||
player_action = {"type": "end_turn"}
|
||||
forced_end = True
|
||||
|
||||
reward = 0.0
|
||||
try:
|
||||
if player_action.get("type") == "end_turn":
|
||||
|
|
@ -134,6 +150,12 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
)
|
||||
|
||||
view = self._client.view()
|
||||
new_turn = int(view.get("turn", 0))
|
||||
if new_turn != self._cur_turn:
|
||||
self._cur_turn = new_turn
|
||||
self._micro_actions_this_turn = 0
|
||||
else:
|
||||
self._micro_actions_this_turn += 1
|
||||
prev_score = self._last_score
|
||||
new_score = float(view.get("score", {}).get("score_estimate", 0.0))
|
||||
reward += SCORE_DELTA_SCALE * (new_score - prev_score)
|
||||
|
|
@ -154,6 +176,8 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
}
|
||||
if reason:
|
||||
info["reason"] = reason
|
||||
if forced_end:
|
||||
info["forced_end_turn"] = True
|
||||
return encode_observation(view), reward, terminated, truncated, info
|
||||
|
||||
def close(self) -> None:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue