feat(@projects/@magic-civilization): ✨ add step_cap evaluation category
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
4a862b76fb
commit
50e174ab06
3 changed files with 46 additions and 31 deletions
|
|
@ -88,9 +88,10 @@ case "$cmd" in
|
|||
launch)
|
||||
remote "
|
||||
cd ${RL_WORKTREE} || exit 1
|
||||
if pgrep -f 'python3 -m tooling.rl_self_play.train' >/dev/null; then
|
||||
existing=\$(ps -eo pid,comm,args | awk '\$2 ~ /^python/ && /rl_self_play.train/ {print \$1}')
|
||||
if [ -n \"\$existing\" ]; then
|
||||
echo 'training already running; run kill first'
|
||||
pgrep -af 'python3 -m tooling.rl_self_play.train'
|
||||
echo \"\$existing\"
|
||||
exit 1
|
||||
fi
|
||||
nohup python3 -m tooling.rl_self_play.train \
|
||||
|
|
@ -100,7 +101,8 @@ case "$cmd" in
|
|||
--run-name ${RL_RUN_NAME} > ${LOG_REMOTE} 2>&1 &
|
||||
echo \$! > ${RL_PIDFILE}
|
||||
sleep 3
|
||||
pgrep -af 'python3 -m tooling.rl_self_play.train' || (echo 'launch failed; check log'; tail -20 ${LOG_REMOTE})
|
||||
ps -eo pid,comm,args | awk '\$2 ~ /^python/ && /rl_self_play.train/' \\
|
||||
|| (echo 'launch failed; check log'; tail -20 ${LOG_REMOTE})
|
||||
"
|
||||
;;
|
||||
|
||||
|
|
|
|||
|
|
@ -52,6 +52,11 @@ def _classify_episode(info_history: list[dict[str, object]], total_reward: float
|
|||
return "loss"
|
||||
if reason == "harness_error":
|
||||
return "loss"
|
||||
if reason == "step_cap":
|
||||
# Policy stuck in a no-progress loop and the env truncated the
|
||||
# whole episode — degenerate non-result, surfaced as its own
|
||||
# category so it's visible in the eval JSON.
|
||||
return "step_cap"
|
||||
# No explicit win yet from the env; use score sign as tiebreaker.
|
||||
if total_reward > 0.5:
|
||||
return "win"
|
||||
|
|
@ -66,7 +71,7 @@ def main() -> int:
|
|||
|
||||
model = MaskablePPO.load(str(args.model_path))
|
||||
|
||||
wins = losses = draws = 0
|
||||
wins = losses = draws = step_caps = 0
|
||||
turns_per_episode: list[int] = []
|
||||
for episode in range(args.episodes):
|
||||
cfg = HarnessConfig(
|
||||
|
|
@ -93,6 +98,8 @@ def main() -> int:
|
|||
wins += 1
|
||||
elif verdict == "loss":
|
||||
losses += 1
|
||||
elif verdict == "step_cap":
|
||||
step_caps += 1
|
||||
else:
|
||||
draws += 1
|
||||
turns_per_episode.append(int(info.get("turn", 0)))
|
||||
|
|
@ -106,10 +113,18 @@ def main() -> int:
|
|||
"wins": wins,
|
||||
"losses": losses,
|
||||
"draws": draws,
|
||||
"step_caps": step_caps,
|
||||
"win_rate": wins / total,
|
||||
"mean_turns": round(mean_turns, 1),
|
||||
}
|
||||
print(json.dumps(verdict))
|
||||
if step_caps:
|
||||
print(
|
||||
f"WARNING: {step_caps}/{args.episodes} eval episodes hit the "
|
||||
f"per-episode step cap — policy got stuck in a no-progress "
|
||||
f"loop. Check encoder/reward shaping.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ its win rate against this baseline; the policy is considered to have
|
|||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
|
|
@ -66,12 +67,12 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
self,
|
||||
harness_config: HarnessConfig | None = None,
|
||||
max_turns: int = 200,
|
||||
max_micro_actions_per_turn: int = DEFAULT_MAX_MICRO_ACTIONS_PER_TURN,
|
||||
max_steps_per_episode: int = DEFAULT_MAX_STEPS_PER_EPISODE,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._config = harness_config or HarnessConfig()
|
||||
self._max_turns = max_turns
|
||||
self._max_micro_actions_per_turn = max_micro_actions_per_turn
|
||||
self._max_steps_per_episode = max_steps_per_episode
|
||||
self.observation_space = spaces.Box(
|
||||
low=-1e6, high=1e6, shape=(OBS_DIM,), dtype=np.float32
|
||||
)
|
||||
|
|
@ -82,8 +83,7 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
self._idx_to_action: dict[int, dict[str, Any]] = {}
|
||||
self._cur_mask: np.ndarray = np.zeros(ACTION_DIM, dtype=bool)
|
||||
self._terminated: bool = False
|
||||
self._cur_turn: int = 0
|
||||
self._micro_actions_this_turn: int = 0
|
||||
self._step_count: int = 0
|
||||
|
||||
# ── Gymnasium API ────────────────────────────────────────────────
|
||||
|
||||
|
|
@ -108,8 +108,7 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
)
|
||||
self._client = HarnessClient(cfg)
|
||||
self._terminated = False
|
||||
self._cur_turn = 0
|
||||
self._micro_actions_this_turn = 0
|
||||
self._step_count = 0
|
||||
view = self._client.view()
|
||||
self._sync_state(view)
|
||||
return encode_observation(view), {"action_mask": self._cur_mask.copy()}
|
||||
|
|
@ -127,18 +126,7 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
# Mask should prevent this, but be defensive: substitute end_turn.
|
||||
idx = 0
|
||||
player_action = decode_action_index(idx, self._idx_to_action)
|
||||
|
||||
# Hard ceiling: if the policy refuses to end its turn after
|
||||
# MAX_MICRO_ACTIONS_PER_TURN, force end_turn. Without this an eval
|
||||
# policy that has learned "ending the turn lowers my reward"
|
||||
# produces an episode of unbounded length.
|
||||
forced_end = False
|
||||
if (
|
||||
self._micro_actions_this_turn >= MAX_MICRO_ACTIONS_PER_TURN
|
||||
and player_action.get("type") != "end_turn"
|
||||
):
|
||||
player_action = {"type": "end_turn"}
|
||||
forced_end = True
|
||||
self._step_count += 1
|
||||
|
||||
reward = 0.0
|
||||
try:
|
||||
|
|
@ -159,12 +147,6 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
)
|
||||
|
||||
view = self._client.view()
|
||||
new_turn = int(view.get("turn", 0))
|
||||
if new_turn != self._cur_turn:
|
||||
self._cur_turn = new_turn
|
||||
self._micro_actions_this_turn = 0
|
||||
else:
|
||||
self._micro_actions_this_turn += 1
|
||||
prev_score = self._last_score
|
||||
new_score = float(view.get("score", {}).get("score_estimate", 0.0))
|
||||
reward += SCORE_DELTA_SCALE * (new_score - prev_score)
|
||||
|
|
@ -174,7 +156,15 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
self._sync_state(view)
|
||||
self._terminated = terminated
|
||||
|
||||
truncated = (not terminated) and int(view.get("turn", 0)) >= self._max_turns
|
||||
step_capped = (
|
||||
not terminated
|
||||
and self._step_count >= self._max_steps_per_episode
|
||||
)
|
||||
turn_capped = (
|
||||
not terminated
|
||||
and int(view.get("turn", 0)) >= self._max_turns
|
||||
)
|
||||
truncated = step_capped or turn_capped
|
||||
if truncated:
|
||||
self._terminated = True
|
||||
info: dict[str, Any] = {
|
||||
|
|
@ -185,8 +175,16 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
}
|
||||
if reason:
|
||||
info["reason"] = reason
|
||||
if forced_end:
|
||||
info["forced_end_turn"] = True
|
||||
elif step_capped:
|
||||
info["reason"] = "step_cap"
|
||||
print(
|
||||
f"[MagicCivEnv] step_cap hit at step={self._step_count} "
|
||||
f"turn={int(view.get('turn', 0))} — truncating episode",
|
||||
file=sys.stderr,
|
||||
flush=True,
|
||||
)
|
||||
elif turn_capped:
|
||||
info["reason"] = "turn_cap"
|
||||
return encode_observation(view), reward, terminated, truncated, info
|
||||
|
||||
def close(self) -> None:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue