diff --git a/src/game/engine/src/modules/ai/ai_turn_bridge.gd b/src/game/engine/src/modules/ai/ai_turn_bridge.gd index 17257bd1..f42b6fbf 100644 --- a/src/game/engine/src/modules/ai/ai_turn_bridge.gd +++ b/src/game/engine/src/modules/ai/ai_turn_bridge.gd @@ -130,7 +130,9 @@ static func _apply_mcts_strategic_override(player: RefCounted) -> void: else MCTS_ROLLOUT_COUNT_EARLY) ctrl.set_rollout_budget(budget) ctrl.set_rollout_depth(MCTS_ROLLOUT_DEPTH) - ctrl.set_gpu_enabled(OS.get_environment("AI_GPU_ROLLOUT") in ["1", "true", "TRUE", "True"]) + # p0-20 Phase C — backend (CPU vs GPU) is now decided once at boot by + # AiBackend::probe(); the AI_GPU_ROLLOUT env var was deleted alongside + # the McSnapshot strategic-MCTS path. MC_AI_BACKEND is the only knob. ctrl.set_priors_enabled( OS.get_environment("AI_MCTS_PRIORS") in ["1", "true", "TRUE", "True"] ) diff --git a/src/simulator/api-gdext/src/ai.rs b/src/simulator/api-gdext/src/ai.rs index 22d24730..07860d83 100644 --- a/src/simulator/api-gdext/src/ai.rs +++ b/src/simulator/api-gdext/src/ai.rs @@ -3,16 +3,18 @@ //! Exposes two Godot RefCounted classes: //! //! - `GdMcTreeController` — strategic layer. Accepts a serialized `GameState` -//! JSON, runs parallel MCTS rollouts via `mc-turn`'s `McSnapshot`, and -//! returns the winning `McAction` as a string GDScript can read. +//! JSON, runs MCTS over `Tree` via the abstract-rollout +//! path (CPU or GPU per the boot-probed [`mc_ai::backend::AiBackend`]), and +//! returns the winning [`mc_ai::policy::ActionKind`] as a string GDScript +//! can read. //! - `GdAiController` — tactical layer (p0-26). Accepts an abstract rollout //! state JSON, runs [`mc_ai::tactical::decide_tactical_actions`], and //! returns a `PackedStringArray` of JSON-encoded `Action` records that the //! GDScript turn bridge dispatches back into the engine. //! -//! All simulation logic lives in `mc-turn` and `mc-ai`. This file is a shim only. +//! All simulation logic lives in `mc-ai` and `mc-turn`. This file is a shim only. -use std::sync::{OnceLock, atomic::{AtomicBool, Ordering}}; +use std::sync::OnceLock; use std::time::{Duration, Instant}; use godot::prelude::*; @@ -21,17 +23,24 @@ use mc_ai::backend::AiBackend; use mc_ai::evaluator::{ScoringEvaluator, ScoringWeights}; use mc_ai::game_state::{AiPlayerState, StrategicWeights}; use mc_ai::mcts::XorShift64; -use mc_ai::mcts_tree::{rollout_snapshot, Tree}; -use mc_ai::tactical::{decide_tactical_actions, Action, TacticalEphemerals, TacticalMap, TacticalState, TacticalTile}; -use mc_mcts_service::protocol::SearchActionJob; +use mc_ai::mcts_tree::Tree; +use mc_ai::policy::{ActionKind, PersonalityPriors}; +use mc_ai::rollout::GameRolloutState; +use mc_ai::tactical::{ + decide_tactical_actions, Action, TacticalEphemerals, TacticalMap, TacticalState, TacticalTile, +}; +use mc_mcts_service::protocol::{AbstractJobState, SearchActionViaAbstractJob}; use mc_mcts_service::server::DEFAULT_SOCKET_PATH; -use mc_turn::snapshot::{McAction, McSnapshot}; -use mc_turn::{GameState, TurnProcessor}; +use mc_turn::abstract_projection::to_abstract_rollout_state; +use mc_turn::GameState; +use std::collections::HashMap; +use std::sync::atomic::{AtomicBool, Ordering}; // ── Service runtime (process-static) ───────────────────────────────────────── static TOKIO_RT: OnceLock> = OnceLock::new(); static SERVICE_WARN_EMITTED: AtomicBool = AtomicBool::new(false); +static BACKEND: OnceLock = OnceLock::new(); fn tokio_rt() -> Option<&'static tokio::runtime::Runtime> { TOKIO_RT @@ -94,45 +103,170 @@ fn which_mcts_server() -> Result { Err(()) } -/// Try to pick an action via the MCTS service using a full tree search. -/// -/// Serialises `snap` as JSON and sends a `SearchAction` request. The server -/// runs `Tree::simulate_parallel` with the same parameters and returns the -/// best action string + win-rate. Returns `None` on any transport or protocol -/// error so the caller falls back to the local in-process path. -fn try_search_action_via_service( - snap: &McSnapshot, - n_rollouts: u32, - depth: u32, +/// Build per-player `PersonalityPriors` from a parsed `GameState` by reading +/// each [`mc_turn::PlayerState::strategic_axes`] map. Slots beyond +/// `MAX_PLAYERS` are dropped; missing slots default to neutral. +fn build_priors_from_game_state(state: &GameState) -> [PersonalityPriors; MAX_PLAYERS] { + let mut out = [PersonalityPriors::default(); MAX_PLAYERS]; + for (i, p) in state.players.iter().take(MAX_PLAYERS).enumerate() { + let axes: HashMap = p + .strategic_axes + .iter() + .map(|(k, v)| (k.clone(), *v as i32)) + .collect(); + out[i] = PersonalityPriors::from_axes(&axes); + } + out +} + +/// Map an [`ActionKind`] to its stable lower-CamelCase debug name (matches +/// the variant identifier so GDScript can switch on the exact string). +fn action_kind_name(kind: ActionKind) -> &'static str { + match kind { + ActionKind::Build => "Build", + ActionKind::Attack => "Attack", + ActionKind::Settle => "Settle", + ActionKind::Research => "Research", + ActionKind::Defend => "Defend", + ActionKind::Trade => "Trade", + ActionKind::ContinueWar => "ContinueWar", + ActionKind::MakePeace => "MakePeace", + ActionKind::Idle => "Idle", + ActionKind::CommandFormation => "CommandFormation", + ActionKind::SetRallyPoint => "SetRallyPoint", + } +} + +/// Run the abstract-rollout MCTS in-process. Returns the chosen +/// [`ActionKind`], the win-rate at the chosen child, the total visit count +/// at the root, and the per-action visit-count breakdown. +fn run_abstract_search( + state: &GameState, + root_player: u8, base_seed: u64, + rollout_budget: u32, + rollout_depth: u32, use_priors: bool, budget_ms: u64, -) -> Option<(McAction, f32, u32, u32)> { - let snapshot_json = serde_json::to_string(snap).ok()?; - let job = SearchActionJob { - snapshot_json, - root_player: snap.active_player, - n_rollouts, - depth, - seed: base_seed, - use_priors, - budget_ms, +) -> (ActionKind, f32, u32, Vec<(ActionKind, u32)>) { + let pod = to_abstract_rollout_state(state); + let priors = build_priors_from_game_state(state); + + let mut tree = Tree::new(GameRolloutState::new(pod, priors)); + tree.use_priors = use_priors; + tree.root_player = (root_player as usize).min(MAX_PLAYERS - 1) as u8; + tree.rollout_horizon = rollout_depth.max(1); + + let backend = BACKEND.get_or_init(AiBackend::probe); + + // Tunable: 1024 leaves per dispatch matches the persistent-buffer + // MAX_BATCH in `mc_ai::gpu::inner`. Phase B raised this from 32 so + // each `iterate_gpu_batched` call ends up as one wgpu submit and the + // per-submit overhead amortizes only when the batch is big. + const BATCH_SIZE: usize = 1024; + + let total_budget = rollout_budget as usize; + let wall_budget = if budget_ms > 0 { Some(budget_ms) } else { None }; + let start = Instant::now(); + let mut completed: usize = 0; + while completed < total_budget { + if let Some(b) = wall_budget { + if start.elapsed() >= Duration::from_millis(b) { + break; + } + } + let remaining = total_budget - completed; + let this_batch = remaining.min(BATCH_SIZE); + let dispatched = tree.iterate_gpu_batched( + this_batch, + base_seed.wrapping_add(completed as u64), + wall_budget, + backend, + ); + if dispatched == 0 { + break; + } + completed += dispatched; + } + + // Best action = highest visit at root child. + let action = tree + .most_visited_action_at_root() + .unwrap_or(ActionKind::Idle); + + // Win rate of the chosen child. + let mut chosen_visits: u32 = 0; + let mut chosen_wins: f32 = 0.0; + let mut breakdown: Vec<(ActionKind, u32)> = Vec::new(); + for &ci in &tree.root().children { + let n = &tree.nodes[ci]; + let Some(a) = n.action else { continue }; + breakdown.push((a, n.visits)); + if a == action { + chosen_visits = n.visits; + chosen_wins = n.wins; + } + } + let win_rate = if chosen_visits > 0 { + chosen_wins / chosen_visits as f32 + } else { + 0.5 + }; + + (action, win_rate, tree.root().visits, breakdown) +} + +/// Try the `mcts-server` service path. Returns `None` on any transport, +/// protocol, or runtime error so the caller falls back to the local path. +fn try_search_action_via_service( + state: &GameState, + root_player: u8, + base_seed: u64, + rollout_budget: u32, + budget_ms: u64, +) -> Option<(ActionKind, f32, u32, u32)> { + let pod = to_abstract_rollout_state(state); + let priors = build_priors_from_game_state(state); + + let job = SearchActionViaAbstractJob { + abstract_state: AbstractJobState::from_pod(&pod), + priors, + root_player, + rollout_budget, + base_seed, + budget_ms: if budget_ms > 0 { Some(budget_ms) } else { None }, }; let sock = socket_path(); let rt = tokio_rt()?; let result = rt - .block_on(mc_mcts_service::client::submit_search_action(&sock, job)) + .block_on(mc_mcts_service::client::submit_search_action_via_abstract( + &sock, job, + )) .ok()?; - let action = match result.action.as_str() { - "FoundCity" => McAction::FoundCity, - "SpawnUnit" => McAction::SpawnUnit, - _ => McAction::Idle, - }; + let action = parse_action_kind(&result.action); Some((action, result.win_rate, result.n_rollouts, result.took_ms)) } +/// Inverse of `action_kind_name`. Unknown strings fall back to `Idle` so the +/// bridge always gets a well-formed value even if the protocol drifts. +fn parse_action_kind(name: &str) -> ActionKind { + match name { + "Build" => ActionKind::Build, + "Attack" => ActionKind::Attack, + "Settle" => ActionKind::Settle, + "Research" => ActionKind::Research, + "Defend" => ActionKind::Defend, + "Trade" => ActionKind::Trade, + "ContinueWar" => ActionKind::ContinueWar, + "MakePeace" => ActionKind::MakePeace, + "CommandFormation" => ActionKind::CommandFormation, + "SetRallyPoint" => ActionKind::SetRallyPoint, + _ => ActionKind::Idle, + } +} + // ── GdMcTreeController ─────────────────────────────────────────────────────── #[derive(GodotClass)] @@ -143,23 +277,14 @@ pub struct GdMcTreeController { /// Max turns per rollout (depth cap so headless rollouts don't run forever). rollout_depth: u32, /// Per-decision wall-clock budget in milliseconds. `0` means unbounded - /// (default). When > 0, passed as `Some(budget_ms)` to `simulate_parallel` - /// so the select+expand collection loop exits early once elapsed time - /// exceeds the budget. Set via `set_budget_ms` (driven by - /// `MCTS_DECISION_BUDGET_MS` env on the GDScript side). See p1-22. + /// (default). When > 0, threaded into `iterate_gpu_batched` so the + /// outer batch loop exits early once elapsed time exceeds the budget. budget_ms: u64, - /// Boot-probed AI backend used by batched-rollout call sites (Phase 2+ - /// of p0-20). Phase 1 plumbs this onto the controller and logs the - /// adapter at construction; the live `choose_action` path still uses - /// `Tree::simulate_parallel` with CPU rollouts. + /// Boot-probed AI backend used by batched-rollout call sites. ai_backend: AiBackend, /// When true, Trees use PUCT selection with per-node priors instead of - /// classical UCB1 (p0-38). Toggled by `set_priors_enabled` (driven by - /// `AI_MCTS_PRIORS` env). Default `true`; set `AI_MCTS_PRIORS=false` to - /// revert to UCB1. Both `McSnapshot` and `GameRolloutState` override - /// `action_prior` with personality-weighted values — `McSnapshot` via - /// `ScoringWeights` fields, `GameRolloutState` via `PersonalityPriors` - /// softmax over a 9-kind action taxonomy. + /// classical UCB1 (p0-38). Default `true`; set `AI_MCTS_PRIORS=false` to + /// revert to UCB1. priors_enabled: bool, base: Base, } @@ -202,55 +327,40 @@ impl GdMcTreeController { } /// Set the per-decision wall-clock budget in milliseconds (p1-22). - /// Pass `0` (default) for unbounded behavior. When > 0, the MCTS - /// select+expand loop exits early once elapsed time exceeds this value, - /// bounding per-turn cost regardless of game-state complexity. - /// - /// Called from `ai_turn_bridge.gd` based on the `MCTS_DECISION_BUDGET_MS` env. + /// Pass `0` (default) for unbounded behavior. #[func] fn set_budget_ms(&mut self, ms: i64) { self.budget_ms = ms.max(0) as u64; } - /// Phase-1 stub: GPU enable is now decided once at construction by - /// `AiBackend::probe()`. The setter is retained so the GDScript - /// `ai_turn_bridge.gd` shim keeps compiling without code changes; calls - /// are logged but no longer toggle behaviour. Phase 2+ removes this - /// surface alongside the GDScript-side env-flag lookup. - #[func] - fn set_gpu_enabled(&mut self, enabled: bool) { - godot_print!( - "GdMcTreeController::set_gpu_enabled({}) ignored — backend fixed at boot to {}", - enabled, - self.ai_backend.name() - ); - } - /// Enable or disable PUCT selection with per-node priors (p0-38). /// Toggled by `ai_turn_bridge.gd` based on the `AI_MCTS_PRIORS` env. - /// Default `true`; set `AI_MCTS_PRIORS=false` to revert to UCB1. - /// - /// Both `McSnapshot` and `GameRolloutState` implement personality-weighted - /// `action_prior`: `McSnapshot` maps actions to `ScoringWeights` fields - /// (`military_base` for SpawnUnit, `expansion_base` for FoundCity); - /// `GameRolloutState` delegates to `PersonalityPriors::action_prior` over - /// a richer 9-kind action taxonomy. PUCT priors are therefore active for - /// the strategic driver at tree-selection time. Observable clan divergence - /// in tree shape depends on the tree being expanded across multiple levels — - /// see the `simulate_parallel` vs `iterate` pattern in `mcts_tree.rs`. #[func] fn set_priors_enabled(&mut self, enabled: bool) { self.priors_enabled = enabled; } - /// Run MCTS from the serialized `game_state_json` for `player_index` and return - /// the best `McAction` as a string: `"Idle"`, `"FoundCity"`, or `"SpawnUnit"`. + /// Run MCTS from the serialized `game_state_json` for `player_index` and + /// return the best action as a string drawn from + /// [`mc_ai::policy::ActionKind`] — `"Build"`, `"Attack"`, `"Settle"`, + /// `"Research"`, `"Defend"`, `"Trade"`, `"ContinueWar"`, `"MakePeace"`, + /// or `"Idle"`. /// - /// Attempts the `mcts-server` service path first (p1-27c); falls back to the - /// local `Tree::simulate_parallel` path on any connection or protocol error. - /// Log tag `"mcts: service"` or `"mcts: local"` indicates which path ran. + /// Pipeline (p0-20 Phase C): + /// 1. Parse `game_state_json` → `mc_turn::GameState`. + /// 2. Project to `AbstractRolloutState` via + /// `mc_turn::abstract_projection::to_abstract_rollout_state`. + /// 3. Build per-player `PersonalityPriors` from each player's + /// `strategic_axes` map (`PersonalityPriors::from_axes`). + /// 4. Construct `Tree`; iterate `iterate_gpu_batched` + /// until the rollout budget or wall-clock budget is met. The + /// GPU-vs-CPU routing is decided by the boot-probed + /// [`mc_ai::backend::AiBackend`]. + /// 5. Read the highest-visit-count action. /// - /// Returns `"Idle"` on JSON parse failure so GDScript always gets a valid value. + /// Attempts the `mcts-server` service path first when reachable; falls + /// back to the local in-process path on any service error. Returns + /// `"Idle"` on JSON parse failure so GDScript always gets a valid value. #[func] fn choose_action(&self, game_state_json: GString, player_index: i64, seed: i64) -> GString { let state: GameState = match serde_json::from_str(&game_state_json.to_string()) { @@ -261,89 +371,46 @@ impl GdMcTreeController { } }; - let processor = TurnProcessor::new(300); - let mut snapshot = McSnapshot::from_game_state(&state, &processor); - let pi = player_index.max(0) as usize; - snapshot.active_player = pi as u8; + let root_player = player_index.max(0).min(MAX_PLAYERS as i64 - 1) as u8; let base_seed = seed as u64; - // Service path (p1-27c): full tree search via SearchAction request. - if let Some((action, _win_rate, _n, _ms)) = try_search_action_via_service( - &snapshot, - self.rollout_budget, - self.rollout_depth, + // Service path — use the abstract runner when reachable. + if let Some((action, _wr, _n, _ms)) = try_search_action_via_service( + &state, + root_player, base_seed, - self.priors_enabled, + self.rollout_budget, self.budget_ms, ) { godot_print!("mcts: service"); - return GString::from(match action { - McAction::Idle => "Idle", - McAction::FoundCity => "FoundCity", - McAction::SpawnUnit => "SpawnUnit", - }); + return GString::from(action_kind_name(action)); } - // Service unavailable — warn once then use local path. if !SERVICE_WARN_EMITTED.swap(true, Ordering::Relaxed) { auto_start_service(); godot_warn!("mcts: service unavailable, using local path (mcts: local)"); } godot_print!("mcts: local"); - let depth = self.rollout_depth; - let mut tree = Tree::new(snapshot); - tree.use_priors = self.priors_enabled; - - let rollout_fn = move |snap: &McSnapshot, rng: &mut XorShift64| -> f32 { - let step_fn = |s: &McSnapshot, _d: u32, rng: &mut XorShift64| { - let actions = s.legal_actions(); - if actions.is_empty() { - return s.clone(); - } - let idx = rng.next_u64() as usize % actions.len(); - s.step(&actions[idx]) - }; - let score_fn = |s: &McSnapshot| -> f32 { - if let Some(winner) = s.winner() { - if winner == pi { 1.0 } else { 0.0 } - } else { - s.heuristic_value(pi.min(s.players.len().saturating_sub(1))) - } - }; - rollout_snapshot(snap, rng, depth, &step_fn, &score_fn) - }; - - let budget = if self.budget_ms > 0 { Some(self.budget_ms) } else { None }; - tree.simulate_parallel(self.rollout_budget as usize, base_seed, rollout_fn, budget); - - let root_children = tree.root().children.clone(); - let best_child_idx = root_children - .into_iter() - .max_by_key(|&ci| tree.nodes[ci].visits); - - let action = best_child_idx - .and_then(|ci| tree.nodes[ci].action.clone()) - .unwrap_or(McAction::Idle); - - GString::from(match action { - McAction::Idle => "Idle", - McAction::FoundCity => "FoundCity", - McAction::SpawnUnit => "SpawnUnit", - }) + let (action, _win_rate, _root_visits, _breakdown) = run_abstract_search( + &state, + root_player, + base_seed, + self.rollout_budget, + self.rollout_depth, + self.priors_enabled, + self.budget_ms, + ); + let _ = &self.ai_backend; // backend probed at boot; see godot_print in init + GString::from(action_kind_name(action)) } /// Return the serialized `ScoringWeights` for `clan_id` as a JSON string. /// /// `data_dir` must be the OS filesystem path to the game data directory that - /// contains `ai_personalities.json` (e.g. the globalized `res://public/games/age-of-dwarves/data`). - /// Returns `"{}"` (empty object) on any error so the caller gets `ScoringWeights::default()`. - /// - /// **Deprecated for packed builds (p1-24)**: `std::fs` cannot read from - /// inside a `.pck`. New callers should use `scoring_weights_for_clan_json`. + /// contains `ai_personalities.json`. #[func] fn scoring_weights_for_clan(&self, clan_id: GString, data_dir: GString) -> GString { - use mc_ai::evaluator::ScoringWeights; use std::path::Path; let id = clan_id.to_string(); let dir = data_dir.to_string(); @@ -371,7 +438,6 @@ impl GdMcTreeController { clan_id: GString, personalities_json: GString, ) -> GString { - use mc_ai::evaluator::ScoringWeights; let id = clan_id.to_string(); let json = personalities_json.to_string(); match ScoringWeights::from_personality_json(&id, &json) { @@ -396,13 +462,21 @@ impl GdMcTreeController { } } - /// Convenience: return the best action and the win-rate estimate as a JSON dict. - /// `{ "action": "FoundCity", "win_rate": 0.62, "root_idle": N, ... }` - /// - /// Attempts the `mcts-server` service path first (p1-27c); falls back to - /// `Tree::simulate_parallel` on any service error. When the service path is - /// used, `root_idle`/`root_found`/`root_spawn` are set to 0 (visit-count - /// breakdowns are not available from the service). + /// Convenience: return the best action plus a stats dict as JSON. + /// Shape: + /// ```json + /// { + /// "action": "Settle", + /// "win_rate": 0.62, + /// "rollouts": 1024, + /// "path": "gpu", + /// "root_visits": {"Settle": 640, "Build": 200, "Idle": 184} + /// } + /// ``` + /// `root_visits` is a flat `{ActionKind: u32}` over actually-expanded + /// children — overlay code can iterate without pinning a fixed action set. + /// Phase C dropped the legacy `root_idle` / `root_found` / `root_spawn` + /// fields along with the `Tree` path. #[func] fn choose_action_with_stats( &self, @@ -417,34 +491,29 @@ impl GdMcTreeController { "GdMcTreeController::choose_action_with_stats parse error: {}", e ); - return GString::from(r#"{"action":"Idle","win_rate":0.5}"#); + return GString::from(r#"{"action":"Idle","win_rate":0.5,"rollouts":0,"path":"error","root_visits":{}}"#); } }; - let processor = TurnProcessor::new(300); - let mut snapshot = McSnapshot::from_game_state(&state, &processor); - let pi = player_index.max(0) as usize; - snapshot.active_player = pi as u8; + let root_player = player_index.max(0).min(MAX_PLAYERS as i64 - 1) as u8; let base_seed = seed as u64; - // Service path (p1-27c): full tree search via SearchAction request. - // root_idle/found/spawn visit counts are not returned by the service (set to 0). - if let Some((action, win_rate, _n, _ms)) = try_search_action_via_service( - &snapshot, - self.rollout_budget, - self.rollout_depth, + // Service path first. + if let Some((action, win_rate, n_rollouts, _ms)) = try_search_action_via_service( + &state, + root_player, base_seed, - self.priors_enabled, + self.rollout_budget, self.budget_ms, ) { godot_print!("mcts: service"); - let action_str = match action { - McAction::Idle => "Idle", - McAction::FoundCity => "FoundCity", - McAction::SpawnUnit => "SpawnUnit", - }; + // The service does not return per-action breakdowns; emit an + // empty `root_visits` rather than guessing. return GString::from(format!( - r#"{{"action":"{action_str}","win_rate":{win_rate:.4},"root_idle":0,"root_found":0,"root_spawn":0}}"# + r#"{{"action":"{}","win_rate":{:.4},"rollouts":{},"path":"service","root_visits":{{}}}}"#, + action_kind_name(action), + win_rate, + n_rollouts )); } @@ -454,236 +523,41 @@ impl GdMcTreeController { } godot_print!("mcts: local"); - let depth = self.rollout_depth; - let mut tree = Tree::new(snapshot); - tree.use_priors = self.priors_enabled; + let (action, win_rate, root_visits, breakdown) = run_abstract_search( + &state, + root_player, + base_seed, + self.rollout_budget, + self.rollout_depth, + self.priors_enabled, + self.budget_ms, + ); - let rollout_fn = move |snap: &McSnapshot, rng: &mut XorShift64| -> f32 { - let step_fn = |s: &McSnapshot, _d: u32, rng: &mut XorShift64| { - let actions = s.legal_actions(); - if actions.is_empty() { - return s.clone(); - } - let idx = rng.next_u64() as usize % actions.len(); - s.step(&actions[idx]) - }; - let score_fn = |s: &McSnapshot| -> f32 { - if let Some(winner) = s.winner() { - if winner == pi { 1.0 } else { 0.0 } - } else { - s.heuristic_value(pi.min(s.players.len().saturating_sub(1))) - } - }; - rollout_snapshot(snap, rng, depth, &step_fn, &score_fn) - }; - - let budget = if self.budget_ms > 0 { Some(self.budget_ms) } else { None }; - tree.simulate_parallel(self.rollout_budget as usize, base_seed, rollout_fn, budget); - - let root = tree.root(); - let root_children = root.children.clone(); - let best_child_idx = root_children - .into_iter() - .max_by_key(|&ci| tree.nodes[ci].visits); - - let (action, win_rate) = if let Some(ci) = best_child_idx { - let n = &tree.nodes[ci]; - let rate = if n.visits > 0 { - n.wins / n.visits as f32 - } else { - 0.5 - }; - (n.action.clone().unwrap_or(McAction::Idle), rate) - } else { - (McAction::Idle, 0.5) - }; - - let action_str = match action { - McAction::Idle => "Idle", - McAction::FoundCity => "FoundCity", - McAction::SpawnUnit => "SpawnUnit", - }; - - let root = tree.root(); - let mut visits_idle = 0u32; - let mut visits_found = 0u32; - let mut visits_spawn = 0u32; - for &ci in &root.children { - let n = &tree.nodes[ci]; - match &n.action { - Some(McAction::Idle) => visits_idle = n.visits, - Some(McAction::FoundCity) => visits_found = n.visits, - Some(McAction::SpawnUnit) => visits_spawn = n.visits, - None => {} + let mut visits_json = String::with_capacity(64); + visits_json.push('{'); + let mut first = true; + for (a, v) in &breakdown { + if !first { + visits_json.push(','); } + first = false; + visits_json.push('"'); + visits_json.push_str(action_kind_name(*a)); + visits_json.push_str("\":"); + visits_json.push_str(&v.to_string()); } + visits_json.push('}'); + let path = if self.ai_backend.is_gpu() { "gpu" } else { "cpu" }; GString::from(format!( - r#"{{"action":"{action_str}","win_rate":{win_rate:.4},"root_idle":{visits_idle},"root_found":{visits_found},"root_spawn":{visits_spawn}}}"# + r#"{{"action":"{}","win_rate":{:.4},"rollouts":{},"path":"{}","root_visits":{}}}"#, + action_kind_name(action), + win_rate, + root_visits, + path, + visits_json )) } - - /// p0-20 Phase A v3 — choose an action via the abstract `GameRolloutState` - /// path. Coexists with `choose_action` / `choose_action_with_stats` - /// (which still drive `Tree`), giving the GPU MCTS path its - /// own callable entry point without disturbing the live-game flow. - /// - /// Pipeline: - /// 1. Parse `game_state_json` → `mc_turn::GameState`. - /// 2. Project to `AbstractRolloutState` via - /// `mc_turn::abstract_projection::to_abstract_rollout_state`. - /// 3. Build a `Tree` rooted at `(pod, priors_per_player)`, - /// with `priors_per_player` decoded from `priors_json` — a JSON array - /// of up to `MAX_PLAYERS` `PersonalityPriors` objects (missing slots - /// default). - /// 4. Outer loop: call `tree.iterate_gpu_batched(batch_size, …)` until the - /// rollout budget is met OR the wall-clock budget (`budget_ms`) expires. - /// Routing GPU vs CPU is decided by the boot-probed - /// [`mc_ai::backend::AiBackend`]. - /// 5. Return the highest-visit-count action via - /// `Tree::most_visited_action_at_root` as the `ActionKind` debug name - /// (`"Build"`, `"Settle"`, etc.). Empty trees return `"Idle"`. - /// - /// Determinism: same `(game_state_json, priors_json, root_player, base_seed, - /// rollout_budget)` produces the same action across runs and across - /// platforms with the same probed backend (CPU and GPU paths are - /// byte-equivalent — see `tests/gpu_rollout_parity.rs`). - #[func] - fn choose_action_via_abstract( - &self, - game_state_json: GString, - priors_json: GString, - root_player: i64, - base_seed: i64, - ) -> GString { - use mc_ai::abstract_state::MAX_PLAYERS; - use mc_ai::policy::{ActionKind, PersonalityPriors}; - use mc_ai::rollout::GameRolloutState; - use mc_turn::abstract_projection::to_abstract_rollout_state; - - // 1. Parse GameState. - let state: GameState = match serde_json::from_str(&game_state_json.to_string()) { - Ok(s) => s, - Err(e) => { - godot_error!( - "GdMcTreeController::choose_action_via_abstract parse error: {}", e - ); - return GString::from("Idle"); - } - }; - - // 2. Project to abstract POD. - let pod = to_abstract_rollout_state(&state); - - // 3. Decode priors into a fixed-size array, defaulting missing slots. - let priors_arr: [PersonalityPriors; MAX_PLAYERS] = { - let mut out = [PersonalityPriors::default(); MAX_PLAYERS]; - let raw = priors_json.to_string(); - if !raw.trim().is_empty() && raw.trim() != "null" { - match serde_json::from_str::>(&raw) { - Ok(v) => { - for (i, p) in v.into_iter().take(MAX_PLAYERS).enumerate() { - out[i] = p; - } - } - Err(e) => { - godot_error!( - "GdMcTreeController::choose_action_via_abstract priors parse error: {}", - e - ); - } - } - } - out - }; - - // 4. Build the rollout-state tree. The root carries the abstract POD - // + per-player priors; downstream `iterate_gpu_batched` clones the - // POD into the GPU/CPU batch path on every leaf expansion. - let root_state = GameRolloutState::new(pod, priors_arr); - let mut tree = Tree::new(root_state); - tree.use_priors = self.priors_enabled; - let pi = root_player.max(0) as usize; - tree.root_player = pi.min(MAX_PLAYERS - 1) as u8; - - // Backend + outer N-rollout loop. AiBackend is process-static — probe - // once and cache so successive calls don't pay the adapter probe cost. - static BACKEND: OnceLock = OnceLock::new(); - let backend = BACKEND.get_or_init(AiBackend::probe); - - // Tunable: 1024 leaves per dispatch matches the persistent-buffer - // MAX_BATCH in `mc_ai::gpu::inner`. p0-20 Phase B raised this from - // 32: each `iterate_gpu_batched` call ends up as one wgpu submit, - // and the per-submit overhead amortizes only when the batch is - // big. The kernel's @workgroup_size(64) lets 1024 leaves dispatch - // as 16 workgroups in one submit. Note: bumping batch size also - // means more leaves are selected against stale visit counts before - // the first backprop; tree shape can shift vs. the smaller-batch - // version. That's an algorithmic change measured in Phase C, not a - // bug. - const BATCH_SIZE: usize = 1024; - - let total_budget = self.rollout_budget as usize; - let wall_budget = if self.budget_ms > 0 { - Some(self.budget_ms) - } else { - None - }; - let start = Instant::now(); - let mut completed: usize = 0; - let bs = base_seed as u64; - while completed < total_budget { - if let Some(b) = wall_budget { - if start.elapsed() >= Duration::from_millis(b) { - break; - } - } - let remaining = total_budget - completed; - let this_batch = remaining.min(BATCH_SIZE); - // Each batch carries its own seed offset so identical inputs - // produce identical visit counts across runs. - let dispatched = tree.iterate_gpu_batched( - this_batch, - bs.wrapping_add(completed as u64), - wall_budget, - backend, - ); - if dispatched == 0 { - // Either terminal-at-root, batch_size=0, or backend Err. - // Surface and stop — no silent fallback (see - // `iterate_gpu_batched` docstring). - break; - } - completed += dispatched; - } - - // 5. Read the highest-visit child action. - let action_name = match tree.most_visited_action_at_root() { - Some(a) => action_kind_name(a), - None => "Idle", - }; - GString::from(action_name) - } -} - -/// Stable, ordinal-frozen lower-CamelCase debug name for an -/// [`mc_ai::policy::ActionKind`]. Matches the variant identifier so GDScript -/// can switch on the exact string without consulting a lookup table. -fn action_kind_name(kind: mc_ai::policy::ActionKind) -> &'static str { - use mc_ai::policy::ActionKind; - match kind { - ActionKind::Build => "Build", - ActionKind::Attack => "Attack", - ActionKind::Settle => "Settle", - ActionKind::Research => "Research", - ActionKind::Defend => "Defend", - ActionKind::Trade => "Trade", - ActionKind::ContinueWar => "ContinueWar", - ActionKind::MakePeace => "MakePeace", - ActionKind::Idle => "Idle", - ActionKind::CommandFormation => "CommandFormation", - ActionKind::SetRallyPoint => "SetRallyPoint", - } } // ── GdAiController ─────────────────────────────────────────────────────────── @@ -697,25 +571,6 @@ fn action_kind_name(kind: mc_ai::policy::ActionKind) -> &'static str { /// [`Self::set_map`] and updated incrementally via [`Self::update_tile`]), /// runs the tactical decision function, and emits each returned `Action` as /// its own JSON string inside a `PackedStringArray`. -/// -/// # Lifecycle -/// -/// 1. At game-start (or after `reset`), GDScript calls `set_map(w, h, tiles_json)` -/// to populate the Rust-resident tile catalog. `tiles_json` is a JSON array -/// of the serde form of `TacticalTile`. -/// 2. When tiles mutate (improvement built, owner changed, resource revealed) -/// GDScript calls `update_tile(col, row, tile_json)` with the new tile data. -/// 3. Each AI turn, GDScript calls `decide_actions(ephemerals_json, player_index)` -/// with the serde form of `TacticalEphemerals`. The map is NOT included. -/// 4. On new-game / load-game, GDScript calls `reset()` to clear cached map and -/// weights so stale data from the previous session cannot leak. -/// -/// # Legacy path -/// -/// When `cached_map` is `None` (i.e. `set_map` was never called), `decide_actions` -/// falls back to accepting the old monolithic `TacticalState` JSON (which includes -/// a `"map"` field). This preserves backward compat during the migration and for -/// existing tests that construct full `TacticalState` JSON directly. #[derive(GodotClass)] #[class(base=RefCounted)] pub struct GdAiController { @@ -725,12 +580,7 @@ pub struct GdAiController { /// Deterministic RNG seed, advanced per `decide_actions` call so /// successive turns draw distinct xorshift streams. rng_seed: u64, - /// Per-decision wall-clock budget in milliseconds. `0` means unbounded - /// (default). When > 0, `decide_actions` computes `Instant::now() + budget` - /// and threads it through the tactical submodules so each per-unit / - /// per-city loop exits early once elapsed time exceeds the budget. Set via - /// `set_budget_ms` (driven by `MCTS_DECISION_BUDGET_MS` env on the GDScript - /// side). See p1-22. + /// Per-decision wall-clock budget in milliseconds. `0` means unbounded. budget_ms: u64, /// Rust-resident tile catalog. Set once at game-start via `set_map` and /// mutated incrementally via `update_tile`. When `None`, `decide_actions` @@ -755,16 +605,6 @@ impl IRefCounted for GdAiController { #[godot_api] impl GdAiController { /// Populate the Rust-resident tile catalog from a full grid. - /// - /// Called once at game-start (and after `reset`). `tiles_json` is a JSON - /// array of objects matching the serde shape of `TacticalTile`: - /// `[{"hex":[col,row],"biome":"...","yields":[f,p,g],"resource":null,"is_coast":false,"owner":null},...]` - /// - /// After `set_map` succeeds, `decide_actions` uses `TacticalEphemerals` - /// JSON (without a `map` field) and assembles the full `TacticalState` - /// internally. If `tiles_json` fails to parse the map is cleared and a - /// godot_error is emitted — subsequent `decide_actions` calls will fall - /// back to the legacy monolithic JSON path. #[func] fn set_map(&mut self, width: i32, height: i32, tiles_json: GString) { let source = tiles_json.to_string(); @@ -784,13 +624,6 @@ impl GdAiController { } /// Update a single tile in the Rust-resident tile catalog. - /// - /// Called when a tile mutates mid-game (improvement built, border expanded, - /// resource revealed, terrain transformed, etc.). `tile_json` is the serde - /// form of a single `TacticalTile`. On parse failure the existing tile data - /// is left intact and a godot_error is emitted. - /// - /// If `set_map` has not been called yet this is a no-op (logs a warning). #[func] fn update_tile(&mut self, col: i32, row: i32, tile_json: GString) { let map = match self.cached_map.as_mut() { @@ -814,7 +647,6 @@ impl GdAiController { return; } }; - // Find the tile by (col, row) and replace it. let w = map.width as i32; let h = map.height as i32; if col >= 0 && row >= 0 && col < w && row < h { @@ -824,7 +656,6 @@ impl GdAiController { return; } } - // Fallback: linear search (handles non-row-major or sparse maps). if let Some(existing) = map.tiles.iter_mut().find(|t| t.hex == (col, row)) { *existing = tile; } else { @@ -837,8 +668,6 @@ impl GdAiController { /// Clear the cached tile map and player weights so stale data from the /// previous game session cannot leak into a new or loaded game. - /// - /// Called from the GameState autoload on new-game and load-game paths. #[func] fn reset(&mut self) { self.cached_map = None; @@ -847,33 +676,21 @@ impl GdAiController { } /// Override the xorshift seed used by the next call to - /// [`Self::decide_actions`]. Seeds are advanced deterministically after - /// each call, so setting the seed pins the action sequence for testing. + /// [`Self::decide_actions`]. Seeds advance deterministically after each + /// call so setting the seed pins the action sequence for testing. #[func] fn set_rng_seed(&mut self, seed: i64) { - // Round-trip through u64 so GDScript can pass any i64 as an opaque - // seed (negatives are valid bit patterns). self.rng_seed = seed as u64; } - /// Set the per-decision wall-clock budget in milliseconds for the tactical - /// AI path. Pass `0` (default) for unbounded behavior. When > 0, - /// `decide_actions` threads `Some(Instant::now() + budget)` through the - /// tactical submodules; their per-unit / per-city / per-citizen loops - /// check the deadline and break early once elapsed time exceeds it. - /// Mirrors `GdMcTreeController::set_budget_ms` for the strategic path. - /// Called from `ai_turn_bridge.gd` based on `MCTS_DECISION_BUDGET_MS` env (p1-22). + /// Set the per-decision wall-clock budget in milliseconds for the + /// tactical AI path. #[func] fn set_budget_ms(&mut self, ms: i64) { self.budget_ms = ms.max(0) as u64; } - /// Install a player's scoring weights from a serialized JSON blob - /// produced by [`mc_ai::evaluator::ScoringWeights`]'s serde impl. - /// - /// Silently ignores out-of-range `player_index`. Logs an error and keeps - /// the prior weights on parse failure — the bridge must never substitute - /// default weights after a caller has explicitly configured a clan. + /// Install a player's scoring weights from a serialized JSON blob. #[func] fn set_player_weights(&mut self, player_index: i64, weights_json: GString) { let slot = match player_index_to_slot(player_index) { @@ -896,17 +713,7 @@ impl GdAiController { } /// Return formation-level MCTS candidates for the player described by - /// `ai_player_state_json` (the serde form of `mc_ai::game_state::AiPlayerState`). - /// - /// Emits one candidate per (formation × enemy city hex) pair for `advance` - /// commands, plus `defend` candidates for each own city when `threat_level > 0.5`, - /// plus `SetRallyPoint` candidates for cities with a barracks-class building. - /// - /// The returned JSON array has the shape of `mc_ai::mcts::Candidate`: - /// `[{"choice_type":"command_formation","choice_id":"cmd_formation:…","base_score":…},…]` - /// - /// Returns `"[]"` (empty JSON array) on any parse failure — the bridge must - /// never silently substitute an incorrect candidate set. + /// `ai_player_state_json`. #[func] fn formation_candidates( &self, @@ -949,37 +756,10 @@ impl GdAiController { /// Decide tactical actions for the player whose turn is encoded in /// `state_json`. - /// - /// When the Rust-resident tile catalog has been populated via `set_map` - /// (the fast path), `state_json` is the serde form of [`TacticalEphemerals`] - /// — a JSON object containing `current_player`, `turn`, `players`, - /// `unit_catalog`, and `difficulty_threshold_mult` but **no** `"map"` field. - /// The cached map is combined with the ephemerals internally to build the - /// full `TacticalState`. - /// - /// When the cached map is not yet available (first-turn race or `reset` not - /// yet followed by `set_map`), the method falls back to the legacy path and - /// expects the full `TacticalState` JSON (which includes a `"map"` field). - /// This preserves backward compatibility for existing tests and for callers - /// that haven't migrated to `set_map` yet. - /// - /// `player_index` is the slot whose [`ScoringWeights`] to use. It - /// MUST match `state.current_player` — callers that pass a mismatch - /// still get actions, but scored under the wrong clan personality. - /// On mismatch the bridge logs a warning and proceeds with - /// `state.current_player`'s weights. - /// - /// Returns a `PackedStringArray` where each entry is a JSON-encoded - /// [`Action`]. On JSON parse failure or out-of-range `player_index` - /// returns an **empty** array and logs a `godot_error!` diagnostic — - /// the bridge NEVER silently substitutes a default state. #[func] fn decide_actions(&mut self, state_json: GString, player_index: i64) -> PackedStringArray { let source = state_json.to_string(); let seed = self.rng_seed; - // Advance the seed deterministically so the next call draws a fresh - // xorshift stream (SplitMix64 step constant, matches - // `abstract_state` per-player RNG seeding). self.rng_seed = self.rng_seed.wrapping_add(0x9E37_79B9_7F4A_7C15); let slot = match player_index_to_slot(player_index) { @@ -990,7 +770,6 @@ impl GdAiController { } }; - // Fast path: use the Rust-resident cached map + ephemerals JSON. let state: TacticalState = if let Some(map) = self.cached_map.clone() { match parse_tactical_ephemerals_json(&source) { Ok(ephemerals) => { @@ -1009,7 +788,6 @@ impl GdAiController { } } } else { - // Legacy fallback: full TacticalState JSON (includes "map" field). match parse_tactical_state_json(&source) { Ok(s) => { if s.current_player as usize != slot { @@ -1052,8 +830,6 @@ pub fn player_index_to_slot(player_index: i64) -> Result { } let slot = player_index as usize; if slot >= MAX_PLAYERS { - // Graceful degradation for games with more players than MAX_PLAYERS (e.g. 5-clan): - // share the last available weight slot rather than erroring and taking no actions. godot_warn!( "player_index {slot} >= MAX_PLAYERS {MAX_PLAYERS} — capping to slot {}", MAX_PLAYERS - 1 @@ -1064,11 +840,6 @@ pub fn player_index_to_slot(player_index: i64) -> Result { } /// Parse a GDScript-supplied [`TacticalEphemerals`] JSON blob (fast path). -/// -/// The accepted JSON shape is the serde form of [`TacticalEphemerals`] — -/// everything in [`TacticalState`] except `"map"`. GDScript's -/// `ai_turn_bridge_state.gd` builds this after the tile catalog has been -/// handed off to the Rust-resident map via `set_map` / `update_tile`. pub fn parse_tactical_ephemerals_json(source: &str) -> Result { if source.trim().is_empty() { return Err("state_json is empty".to_string()); @@ -1076,16 +847,7 @@ pub fn parse_tactical_ephemerals_json(source: &str) -> Result(source).map_err(|e| format!("ephemerals_json: {e}")) } -/// Parse a GDScript-supplied [`TacticalState`] JSON blob. -/// -/// The accepted JSON shape is the serde form of [`TacticalState`] — see -/// `mc_ai::tactical::state` for the field list. GDScript's -/// `ai_turn_bridge.gd` builds this by walking the engine's hex grid and -/// player/unit/city collections. -/// -/// Errors: -/// - Empty / whitespace-only string — returns a descriptive error. -/// - Any serde parse failure — returns the serde error. +/// Parse a GDScript-supplied [`TacticalState`] JSON blob (legacy fallback). pub fn parse_tactical_state_json(source: &str) -> Result { if source.trim().is_empty() { return Err("state_json is empty".to_string()); @@ -1094,14 +856,6 @@ pub fn parse_tactical_state_json(source: &str) -> Result } /// Run [`decide_tactical_actions`] and serialize each returned action. -/// -/// Split out so unit tests can exercise the pure-Rust path without spinning -/// up a Godot runtime. Returns one JSON string per action. Serialization -/// errors are logged and the offending action is dropped — a single bad -/// action must not collapse the whole turn's dispatch. -/// -/// `deadline`: wall-clock deadline forwarded to `decide_tactical_actions`. -/// `None` is the legacy unbounded path. See p1-22. pub fn run_tactical( state: &TacticalState, weights: &ScoringWeights, @@ -1127,131 +881,29 @@ pub fn run_tactical( #[cfg(test)] mod tests { use super::*; - use mc_ai::evaluator::ScoringWeights; - use mc_ai::mcts_tree::TreeState; - use mc_turn::snapshot::{McSnapshot, PlayerSnap}; - use mc_turn::processor::LairCombatConfig; - - fn make_snap(city_count: u32) -> McSnapshot { - let weights = ScoringWeights::default(); - McSnapshot { - turn: 0, - players: vec![ - PlayerSnap { - gold: 100, - city_count, - unit_count: 2, - expansion_points: 0, - culture_total: 0, - wealth: 3, - expansion_axis: 2, - production_axis: 2, - scoring_weights: weights.clone(), - }, - PlayerSnap { - gold: 80, - city_count, - unit_count: 1, - expansion_points: 0, - culture_total: 0, - wealth: 2, - expansion_axis: 2, - production_axis: 2, - scoring_weights: weights, - }, - ], - config: LairCombatConfig::default(), - victory_city_count: 30, - active_player: 0, - } - } - - #[test] - fn tree_state_impl_legal_actions_non_terminal() { - let snap = make_snap(1); - assert!(!snap.legal_actions().is_empty()); - } - - #[test] - fn tree_state_impl_terminal_when_victory_reached() { - let snap = make_snap(30); - assert!(snap.is_terminal()); - assert!(snap.legal_actions().is_empty()); - } - - #[test] - fn tree_apply_matches_snapshot_step() { - let snap = make_snap(2); - let via_apply = snap.apply(&McAction::Idle); - let via_step = snap.step(&McAction::Idle); - assert_eq!(via_apply.turn, via_step.turn); - assert_eq!(via_apply.players[0].gold, via_step.players[0].gold); - } - - /// 1000 rollouts on a 2-player game must produce a win-rate with variance ≤0.05 - /// across two independent runs with different seeds. - #[test] - fn parallel_rollout_variance_within_threshold() { - let snap = make_snap(5); - let mut tree_a = Tree::new(snap.clone()); - let mut tree_b = Tree::new(snap); - - let depth = 10u32; - let rollout_fn = move |s: &McSnapshot, rng: &mut XorShift64| -> f32 { - let step_fn = |st: &McSnapshot, _: u32, rng: &mut XorShift64| { - let actions = st.legal_actions(); - if actions.is_empty() { - return st.clone(); - } - let idx = rng.next_u64() as usize % actions.len(); - st.step(&actions[idx]) - }; - let score_fn = |st: &McSnapshot| st.heuristic_value(0); - rollout_snapshot(s, rng, depth, &step_fn, &score_fn) - }; - - tree_a.simulate_parallel(1000, 42, &rollout_fn, None); - tree_b.simulate_parallel(1000, 99, &rollout_fn, None); - - let rate_a = { - let r = tree_a.root(); - if r.visits > 0 { r.wins / r.visits as f32 } else { 0.5 } - }; - let rate_b = { - let r = tree_b.root(); - if r.visits > 0 { r.wins / r.visits as f32 } else { 0.5 } - }; - - let variance = (rate_a - rate_b).abs(); - assert!( - variance <= 0.05, - "win-rate variance {variance:.4} exceeds 0.05 threshold (rate_a={rate_a:.4}, rate_b={rate_b:.4})" - ); - } - - /// choose_action returns a valid action string for a minimal JSON game state. - #[test] - fn choose_action_returns_valid_action_string() { - use mc_turn::{GameState, PlayerState, CityEcology, MapUnit}; - use mc_city::CityState; - use std::collections::BTreeMap; + use mc_city::CityState; + use mc_turn::game_state::{CityEcology, MapUnit, PlayerState}; + use std::collections::BTreeMap; + fn make_player(idx: u8, gold: i32, cities: usize) -> PlayerState { let mut axes = BTreeMap::new(); - axes.insert("wealth".into(), 3u8); - axes.insert("expansion".into(), 2u8); - axes.insert("production".into(), 2u8); - axes.insert("culture".into(), 2u8); + axes.insert("aggression".into(), 5u8); + axes.insert("expansion".into(), 5u8); + axes.insert("production".into(), 5u8); + axes.insert("wealth".into(), 5u8); + axes.insert("trade_willingness".into(), 5u8); + axes.insert("grudge_persistence".into(), 5u8); - let player = PlayerState { - player_index: 0, - gold: 100, - cities: vec![CityState::default(); 2], - unit_upkeep: vec![0, 0], - strategic_axes: axes.clone(), + PlayerState { + player_index: idx, + gold, + cities: vec![CityState::default(); cities], + unit_upkeep: vec![0; cities], + strategic_axes: axes, expansion_points: 0, - city_buildings: vec![vec![], vec![]], - city_improvements: vec![vec![], vec![]], - city_ecology: vec![CityEcology::default(); 2], + city_buildings: vec![vec![]; cities], + city_improvements: vec![vec![]; cities], + city_ecology: vec![CityEcology::default(); cities], units: vec![MapUnit { col: 0, row: 0, hp: 10, max_hp: 10, attack: 5, defense: 5, @@ -1259,104 +911,75 @@ mod tests { unit_id: "dwarf_warrior".into(), ..MapUnit::default() }], - city_positions: vec![(0, 0), (1, 1)], - capital_position: Some((0, 0)), + city_positions: (0..cities).map(|i| (i as i32, 0i32)).collect(), + capital_position: if cities > 0 { Some((0, 0)) } else { None }, ..PlayerState::default() - }; + } + } - let state = GameState { + fn make_state() -> GameState { + GameState { turn: 1, - players: vec![player.clone(), PlayerState { player_index: 1, ..player }], + players: vec![make_player(0, 100, 2), make_player(1, 80, 2)], grid: None, pending_pvp_attacks: Default::default(), ..GameState::default() - }; - - let json = serde_json::to_string(&state).expect("serialize"); - - // Build controller inline (no Godot runtime in tests). - let processor = TurnProcessor::new(300); - let snapshot = McSnapshot::from_game_state(&state, &processor); - let pi: usize = 0; - let depth = 10u32; - - let mut tree = Tree::new(snapshot); - let rollout_fn = move |s: &McSnapshot, rng: &mut XorShift64| -> f32 { - let step_fn = |st: &McSnapshot, _: u32, rng: &mut XorShift64| { - let actions = st.legal_actions(); - if actions.is_empty() { - return st.clone(); - } - let idx = rng.next_u64() as usize % actions.len(); - st.step(&actions[idx]) - }; - let score_fn = |st: &McSnapshot| -> f32 { - if let Some(winner) = st.winner() { - if winner == pi { 1.0 } else { 0.0 } - } else { - st.heuristic_value(0) - } - }; - rollout_snapshot(s, rng, depth, &step_fn, &score_fn) - }; - - tree.simulate_parallel(1000, 7, rollout_fn, None); - - let best_action = tree - .root() - .children - .iter() - .max_by_key(|&&ci| tree.nodes[ci].visits) - .and_then(|&ci| tree.nodes[ci].action.clone()) - .unwrap_or(McAction::Idle); - - let action_str = match best_action { - McAction::Idle => "Idle", - McAction::FoundCity => "FoundCity", - McAction::SpawnUnit => "SpawnUnit", - }; - - assert!( - ["Idle", "FoundCity", "SpawnUnit"].contains(&action_str), - "unexpected action: {action_str}" - ); - - // Verify JSON is valid too - assert!(!json.is_empty()); + } } - /// Smoke: `try_search_action_via_service` returns a valid action when the - /// mcts-server is reachable. Skips silently when the service is down so - /// CI (no server running) stays green. Run with a live service: - /// - /// ```text - /// tools/run-services.sh services:up - /// cargo test -p magic-civ-physics-gdext --lib mcts_service_round_trip -- --nocapture - /// ``` + /// `run_abstract_search` returns one of the canonical ActionKind names + /// for a non-trivial state. #[test] - fn mcts_service_round_trip() { - let snap = make_snap(2); - let result = try_search_action_via_service(&snap, 50, 5, 12345u64, true, 0); - - // Skip (pass) when the service isn't running — expected in CI. - let (action, win_rate, n_rollouts, _ms) = match result { - None => return, - Some(r) => r, - }; - - let action_str = match action { - McAction::Idle => "Idle", - McAction::FoundCity => "FoundCity", - McAction::SpawnUnit => "SpawnUnit", - }; + fn run_abstract_search_returns_canonical_action() { + let state = make_state(); + let (action, win_rate, root_visits, breakdown) = + run_abstract_search(&state, 0, 7, 64, 10, true, 0); + let name = action_kind_name(action); + let canonical: &[&str] = &[ + "Build", "Attack", "Settle", "Research", "Defend", "Trade", + "ContinueWar", "MakePeace", "Idle", + ]; assert!( - ["Idle", "FoundCity", "SpawnUnit"].contains(&action_str), - "service returned unexpected action: {action_str}" + canonical.contains(&name), + "unexpected action: {name}" ); assert!( (0.0..=1.0).contains(&win_rate), - "win_rate {win_rate} out of [0,1]" + "win_rate {win_rate} out of [0, 1]" ); - assert!(n_rollouts > 0, "n_rollouts must be > 0"); + assert!(root_visits > 0, "root must have at least one visit"); + assert!(!breakdown.is_empty(), "breakdown should be non-empty"); + } + + /// `build_priors_from_game_state` reads each player's strategic_axes + /// into the corresponding `PersonalityPriors` slot. + #[test] + fn build_priors_from_game_state_picks_per_player_axes() { + let mut state = make_state(); + // Override player 0 to be aggressive, player 1 to be peaceful. + state.players[0].strategic_axes.insert("aggression".into(), 9); + state.players[1].strategic_axes.insert("aggression".into(), 1); + let priors = build_priors_from_game_state(&state); + assert!( + priors[0].aggression > priors[1].aggression, + "p0 aggression ({}) must exceed p1 aggression ({})", + priors[0].aggression, + priors[1].aggression + ); + } + + /// Regression: Phase C dropped the McSnapshot path entirely. Ensure + /// `parse_action_kind` round-trips every canonical name via + /// `action_kind_name`. + #[test] + fn parse_action_kind_round_trips_canonical_names() { + for kind in ActionKind::ALL { + let name = action_kind_name(kind); + assert_eq!( + parse_action_kind(name), + kind, + "round-trip failed for {name}" + ); + } } } diff --git a/src/simulator/crates/mc-mcts-service/src/client.rs b/src/simulator/crates/mc-mcts-service/src/client.rs index be8dd94c..f736f6a8 100644 --- a/src/simulator/crates/mc-mcts-service/src/client.rs +++ b/src/simulator/crates/mc-mcts-service/src/client.rs @@ -6,7 +6,9 @@ use tracing::instrument; use crate::error::ServiceError; use crate::framing::{read_frame, write_frame}; -use crate::protocol::{MctsJob, MctsResult, Request, Response, SearchActionJob, SearchActionResult}; +use crate::protocol::{ + MctsJob, MctsResult, Request, Response, SearchActionResult, SearchActionViaAbstractJob, +}; /// A single-connection client that round-trips one [`Request`] to the service. /// @@ -81,22 +83,24 @@ pub async fn submit_batch( } } -/// Submit a full MCTS tree search job and return the best action with stats. +/// Submit a full abstract-rollout MCTS tree search and return the best action +/// with stats. p0-20 Phase C — only strategic-search entry point; the legacy +/// McSnapshot-shaped `submit_search_action` is gone. /// /// Maps `Response::Error` to [`ServiceError::Remote`]. /// /// # Errors /// /// Returns an error if the transport fails or the service responds with `Error`. -pub async fn submit_search_action( +pub async fn submit_search_action_via_abstract( socket_path: impl AsRef + std::fmt::Debug, - job: SearchActionJob, + job: SearchActionViaAbstractJob, ) -> Result { - match round_trip(socket_path, Request::SearchAction(job)).await? { - Response::SearchActionResult(r) => Ok(r), + match round_trip(socket_path, Request::SearchActionViaAbstract(job)).await? { + Response::SearchActionViaAbstractResult(r) => Ok(r), Response::Error { message } => Err(ServiceError::Remote(message)), other => Err(ServiceError::Remote(format!( - "unexpected response to SearchAction request: {other:?}" + "unexpected response to SearchActionViaAbstract request: {other:?}" ))), } } diff --git a/src/simulator/crates/mc-mcts-service/src/lib.rs b/src/simulator/crates/mc-mcts-service/src/lib.rs index c8aac510..4d27331f 100644 --- a/src/simulator/crates/mc-mcts-service/src/lib.rs +++ b/src/simulator/crates/mc-mcts-service/src/lib.rs @@ -18,10 +18,15 @@ //! Version 1 (p1-27b): `Request::Mcts(MctsJob)` / `Response::MctsResult(MctsResult)` and //! `Request::MctsBatch` / `Response::MctsBatchResult`. State encoded as //! `MctsJobState` JSON inside each job. -//! Version 2 (p1-27c): `Request::SearchAction(SearchActionJob)` / -//! `Response::SearchActionResult(SearchActionResult)`. Carries a full -//! `McSnapshot` JSON; server runs `Tree::simulate_parallel` and returns -//! `{ action, win_rate, n_rollouts, took_ms, path }`. +//! Version 3 (p0-20 Phase C): +//! `Request::SearchActionViaAbstract(SearchActionViaAbstractJob)` / +//! `Response::SearchActionViaAbstractResult(SearchActionResult)`. Carries an +//! `AbstractJobState` (mirror of `mc_ai::abstract_state::AbstractRolloutState`) +//! plus per-player `PersonalityPriors`; server runs +//! `Tree::iterate_gpu_batched` (CPU or GPU per probed +//! `AiBackend`) and returns +//! `{ action, win_rate, n_rollouts, took_ms, path }`. The McSnapshot-shaped +//! `Request::SearchAction` from version 2 is **deleted**. //! //! ## Relationship to @model-boss //! diff --git a/src/simulator/crates/mc-mcts-service/src/protocol.rs b/src/simulator/crates/mc-mcts-service/src/protocol.rs index eeb80cdd..5692da89 100644 --- a/src/simulator/crates/mc-mcts-service/src/protocol.rs +++ b/src/simulator/crates/mc-mcts-service/src/protocol.rs @@ -83,43 +83,23 @@ pub struct MctsResult { pub took_ms: u32, } -/// Input for a full MCTS tree search producing an action decision. +/// Result of a full MCTS tree search rooted at an [`AbstractJobState`]. /// -/// Carries a JSON-serialised [`mc_turn::snapshot::McSnapshot`] so the server -/// can run `Tree::simulate_parallel` over the exact same state type -/// `GdMcTreeController` uses locally, giving byte-compatible action selection. -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -pub struct SearchActionJob { - /// JSON-encoded `McSnapshot` (serde form of `mc_turn::snapshot::McSnapshot`). - pub snapshot_json: String, - /// Index of the player MCTS is deciding for (matches `McSnapshot::active_player`). - pub root_player: u8, - /// Number of MCTS iterations (`Tree::simulate_parallel` budget). - pub n_rollouts: u32, - /// Rollout horizon — turns per random playout. - pub depth: u32, - /// Base RNG seed for the tree simulation. - pub seed: u64, - /// When `true`, use PUCT selection with per-node priors (p0-38). When - /// `false`, fall back to classical UCB1. Mirrors `GdMcTreeController::priors_enabled`. - pub use_priors: bool, - /// Per-decision wall-clock budget in milliseconds (`0` = unbounded). Mirrors - /// `GdMcTreeController::budget_ms`. - pub budget_ms: u64, -} - -/// Result of a [`Request::SearchAction`] tree search. +/// p0-20 Phase C — replaces the legacy McSnapshot-shaped `SearchActionResult`. +/// The action string is one of [`mc_ai::policy::ActionKind`]'s canonical +/// debug names (`"Build"`, `"Settle"`, `"Idle"`, …). #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct SearchActionResult { - /// The chosen action: `"Idle"`, `"FoundCity"`, or `"SpawnUnit"`. + /// The chosen [`mc_ai::policy::ActionKind`] as a string. pub action: String, /// Win-rate estimate for the chosen action from the best child node. pub win_rate: f32, - /// Total iterations completed inside `Tree::simulate_parallel`. + /// Total iterations completed inside `Tree::iterate_gpu_batched`. pub n_rollouts: u32, /// Wall-clock milliseconds the search took. pub took_ms: u32, - /// Compute path used by the server (`"cpu"` for v1; `"gpu"` reserved). + /// Compute path used by the server: `"gpu"` when the boot-probed + /// [`mc_ai::backend::AiBackend`] reports `is_gpu()`, else `"cpu"`. pub path: String, } @@ -305,12 +285,9 @@ pub enum Request { Mcts(MctsJob), /// Run multiple flat-rollout jobs sequentially (p1-27b primitive). MctsBatch { jobs: Vec }, - /// Run a full MCTS tree search and return the best action (p1-27c). - SearchAction(SearchActionJob), - /// p0-20 Phase A v3 (schema only) — abstract-rollout MCTS search. - /// Server routing via `iterate_gpu_batched` + `most_visited_action_at_root` - /// lands in Phase C; this variant exists in Phase A so callers (api-gdext, - /// integration tests) can encode requests against a stable wire shape. + /// p0-20 Phase C — abstract-rollout MCTS search. Routes through + /// `Tree::iterate_gpu_batched` server-side, returning + /// the best [`mc_ai::policy::ActionKind`] as a string. SearchActionViaAbstract(SearchActionViaAbstractJob), } @@ -323,8 +300,8 @@ pub enum Response { MctsResult(MctsResult), /// Results for a [`Request::MctsBatch`], one entry per input job. MctsBatchResult { results: Vec }, - /// Result for [`Request::SearchAction`]. - SearchActionResult(SearchActionResult), + /// Result for [`Request::SearchActionViaAbstract`] (p0-20 Phase C). + SearchActionViaAbstractResult(SearchActionResult), /// The service encountered an error processing the request. Error { message: String }, } diff --git a/src/simulator/crates/mc-mcts-service/src/server.rs b/src/simulator/crates/mc-mcts-service/src/server.rs index 3a770267..a9ef864c 100644 --- a/src/simulator/crates/mc-mcts-service/src/server.rs +++ b/src/simulator/crates/mc-mcts-service/src/server.rs @@ -1,16 +1,23 @@ /// Async Unix-socket server that processes [`Request`](crate::protocol::Request) frames. use std::path::Path; +use std::sync::OnceLock; +use std::time::{Duration, Instant}; use tokio::net::{UnixListener, UnixStream}; use tracing::{error, info, instrument, warn}; use crate::error::ServiceError; use crate::framing::{read_frame, write_frame}; -use crate::protocol::{MctsJob, MctsJobState, MctsResult, Request, Response, SearchActionJob, SearchActionResult}; +use crate::protocol::{ + MctsJob, MctsJobState, MctsResult, Request, Response, SearchActionResult, + SearchActionViaAbstractJob, +}; /// Default socket path used when none is provided. pub const DEFAULT_SOCKET_PATH: &str = "/tmp/mc-mcts.sock"; +static BACKEND: OnceLock = OnceLock::new(); + /// Run the MCTS service, listening on `socket_path`. /// /// Removes a stale socket file before binding (handles unclean prior shutdown). @@ -26,13 +33,9 @@ pub async fn run(socket_path: impl AsRef + std::fmt::Debug) -> Result<(), let _ = tokio::fs::remove_file(path).await; let listener = UnixListener::bind(path).map_err(ServiceError::Bind)?; - // Phase 1 of p0-20: probe the AI backend at startup so the chosen path - // is observable in service logs. The strategic search call site below - // still uses CPU rollouts via `Tree::simulate_parallel` — Phase 2 wires - // the boot-probed backend into the search itself. This call exists so - // operators see "[mc-ai backend] Cpu (...)" in `mcts-server.log` and - // can confirm the deployed binary is on the expected backend. - let ai_backend = mc_ai::backend::AiBackend::probe(); + // Probe the AI backend at startup. The strategic search runner below + // dispatches through this backend (Cpu or Gpu(...)). + let ai_backend = BACKEND.get_or_init(mc_ai::backend::AiBackend::probe); info!(backend = %ai_backend.name(), "AiBackend probed"); info!("listening"); @@ -93,18 +96,10 @@ fn dispatch(request: Request) -> Response { } Response::MctsBatchResult { results } } - Request::SearchAction(job) => match run_search_action(&job) { - Ok(result) => Response::SearchActionResult(result), + Request::SearchActionViaAbstract(job) => match run_search_action_via_abstract(&job) { + Ok(result) => Response::SearchActionViaAbstractResult(result), Err(msg) => Response::Error { message: msg }, }, - // p0-20 Phase A v3 — schema-only stub. The runner lands in Phase C. - // Returning a structured error here lets callers exercise the wire - // shape (encode/decode round-trip) without claiming server support. - Request::SearchActionViaAbstract(_) => Response::Error { - message: - "SearchActionViaAbstract is schema-only in Phase A; runner ships in Phase C" - .to_owned(), - }, } } @@ -165,78 +160,106 @@ fn run_job(job: &MctsJob) -> Result { }) } -/// Run a full MCTS tree search from a serialised [`McSnapshot`] and return -/// the best [`McAction`] as a string. +/// Run a full abstract-rollout MCTS search and return the best +/// [`mc_ai::policy::ActionKind`] as a [`SearchActionResult`]. /// -/// Mirrors `GdMcTreeController::choose_action_with_stats` exactly — uses the -/// same `Tree::simulate_parallel` + rollout path so action quality is identical -/// to the local in-process path. GPU context is `None` for v1 (CPU only); -/// `path: "cpu"` is set in the response accordingly. -fn run_search_action(job: &SearchActionJob) -> Result { - use mc_ai::mcts::XorShift64; - use mc_ai::mcts_tree::{rollout_snapshot, Tree}; - use mc_turn::snapshot::{McAction, McSnapshot}; +/// Pipeline: +/// 1. Rebuild `AbstractRolloutState` from the [`SearchActionViaAbstractJob`]'s +/// flat mirror. +/// 2. Construct `Tree` with the supplied per-player priors. +/// 3. Loop `tree.iterate_gpu_batched(BATCH_SIZE, …)` until `rollout_budget` +/// is met OR `budget_ms` expires. +/// 4. Read `tree.most_visited_action_at_root()` — return as the canonical +/// debug-name string. Empty trees return `"Idle"`. +/// +/// p0-20 Phase C — replaces the legacy `run_search_action` runner that drove +/// `Tree::simulate_parallel`. The McSnapshot strategic tree is +/// gone; the only path is abstract. +fn run_search_action_via_abstract( + job: &SearchActionViaAbstractJob, +) -> Result { + use mc_ai::backend::AiBackend; + use mc_ai::mcts_tree::Tree; + use mc_ai::policy::ActionKind; + use mc_ai::rollout::GameRolloutState; - let mut snapshot: McSnapshot = serde_json::from_str(&job.snapshot_json) - .map_err(|e| format!("snapshot_json parse error: {e}"))?; - snapshot.active_player = job.root_player; - let pi = job.root_player as usize; - let depth = job.depth; - let base_seed = job.seed; - let n_rollouts = job.n_rollouts.max(1) as usize; - let budget = if job.budget_ms > 0 { Some(job.budget_ms) } else { None }; + let pod = job.abstract_state.to_pod(); + let priors = job.priors; + let mut tree = Tree::new(GameRolloutState::new(pod, priors)); + tree.use_priors = true; + tree.root_player = job.root_player; - let mut tree = Tree::new(snapshot); - tree.use_priors = job.use_priors; + let backend = BACKEND.get_or_init(AiBackend::probe); - let rollout_fn = move |snap: &McSnapshot, rng: &mut XorShift64| -> f32 { - let step_fn = |s: &McSnapshot, _d: u32, rng: &mut XorShift64| { - let actions = s.legal_actions(); - if actions.is_empty() { - return s.clone(); + const BATCH_SIZE: usize = 1024; + let total_budget = job.rollout_budget as usize; + let wall_budget = job.budget_ms; + + let start = Instant::now(); + let mut completed: usize = 0; + while completed < total_budget { + if let Some(b) = wall_budget { + if start.elapsed() >= Duration::from_millis(b) { + break; } - let idx = rng.next_u64() as usize % actions.len(); - s.step(&actions[idx]) - }; - let score_fn = |s: &McSnapshot| -> f32 { - if let Some(winner) = s.winner() { - if winner == pi { 1.0 } else { 0.0 } - } else { - s.heuristic_value(pi.min(s.players.len().saturating_sub(1))) - } - }; - rollout_snapshot(snap, rng, depth, &step_fn, &score_fn) + } + let remaining = total_budget - completed; + let this_batch = remaining.min(BATCH_SIZE); + let dispatched = tree.iterate_gpu_batched( + this_batch, + job.base_seed.wrapping_add(completed as u64), + wall_budget, + backend, + ); + if dispatched == 0 { + break; + } + completed += dispatched; + } + + let action = tree + .most_visited_action_at_root() + .unwrap_or(ActionKind::Idle); + + // Win rate at the chosen child. + let mut chosen_visits = 0u32; + let mut chosen_wins = 0.0f32; + for &ci in &tree.root().children { + let n = &tree.nodes[ci]; + if n.action == Some(action) { + chosen_visits = n.visits; + chosen_wins = n.wins; + break; + } + } + let win_rate = if chosen_visits > 0 { + chosen_wins / chosen_visits as f32 + } else { + 0.5 }; - let start = std::time::Instant::now(); - tree.simulate_parallel(n_rollouts, base_seed, rollout_fn, budget); let took_ms = start.elapsed().as_millis().min(u32::MAX as u128) as u32; - // Robust child: highest visit count. - let root_children = tree.root().children.clone(); - let best_child = root_children - .into_iter() - .max_by_key(|&ci| tree.nodes[ci].visits); - - let (action, win_rate) = if let Some(ci) = best_child { - let n = &tree.nodes[ci]; - let rate = if n.visits > 0 { n.wins / n.visits as f32 } else { 0.5 }; - (n.action.clone().unwrap_or(McAction::Idle), rate) - } else { - (McAction::Idle, 0.5) + let action_name = match action { + ActionKind::Build => "Build", + ActionKind::Attack => "Attack", + ActionKind::Settle => "Settle", + ActionKind::Research => "Research", + ActionKind::Defend => "Defend", + ActionKind::Trade => "Trade", + ActionKind::ContinueWar => "ContinueWar", + ActionKind::MakePeace => "MakePeace", + ActionKind::Idle => "Idle", + ActionKind::CommandFormation => "CommandFormation", + ActionKind::SetRallyPoint => "SetRallyPoint", }; - - let n_completed = tree.root().visits; + let path = if backend.is_gpu() { "gpu" } else { "cpu" }.to_string(); Ok(SearchActionResult { - action: match action { - McAction::Idle => "Idle".to_owned(), - McAction::FoundCity => "FoundCity".to_owned(), - McAction::SpawnUnit => "SpawnUnit".to_owned(), - }, + action: action_name.to_string(), win_rate, - n_rollouts: n_completed, + n_rollouts: tree.root().visits, took_ms, - path: "cpu".to_owned(), + path, }) } diff --git a/src/simulator/crates/mc-turn/src/lib.rs b/src/simulator/crates/mc-turn/src/lib.rs index 97ac4d9c..4185c8dc 100644 --- a/src/simulator/crates/mc-turn/src/lib.rs +++ b/src/simulator/crates/mc-turn/src/lib.rs @@ -30,7 +30,6 @@ pub mod game_state; pub mod combat_event; pub mod processor; pub mod prologue; -pub mod snapshot; pub mod spatial_index; pub mod victory; pub mod courier_resolver; @@ -65,6 +64,5 @@ pub use prologue::{ StartMode, Wanderer, WandererDirection, DEFAULT_LUCKY_INWARD_BIAS_PROB, LUCKY_MAX_BONUS_POP, LUCKY_POP_PER_EXTRA_WANDERERS, MIN_WANDERERS_TO_FORM_TRIBE, TRIBE_CONVERGENCE_RADIUS, }; -pub use snapshot::{McAction, McSnapshot, PlayerSnap}; pub use spatial_index::LairIndexCsr; pub use victory::{VictoryConfig, VictoryType}; diff --git a/src/simulator/crates/mc-turn/src/snapshot.rs b/src/simulator/crates/mc-turn/src/snapshot.rs deleted file mode 100644 index 51e33dc7..00000000 --- a/src/simulator/crates/mc-turn/src/snapshot.rs +++ /dev/null @@ -1,378 +0,0 @@ -//! Compact, cloneable game snapshot for MCTS rollouts. -//! -//! `McSnapshot` captures the fields `TurnProcessor::step` actually mutates -//! (economy, city count, unit count, turn counter) without carrying the full -//! `GridState` — keeping it cheap to clone across rayon threads. -//! -//! `McSnapshot::step` is byte-identical to `TurnProcessor::step` for every -//! field it tracks. Tests in this module assert that invariant. - -use crate::game_state::{GameState, PlayerState}; -use crate::processor::{LairCombatConfig, TurnProcessor}; -use mc_ai::evaluator::ScoringWeights; -use mc_ai::mcts_tree::TreeState; -use serde::{Deserialize, Serialize}; - -// ── Action ────────────────────────────────────────────────────────────────── - -/// An atomic AI decision that `McSnapshot::step` can apply. -/// -/// Only economy-relevant choices are modelled here; spatial decisions -/// (unit movement, attack targeting) are out of scope until B-phase GPU work. -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub enum McAction { - /// Do nothing — let the economy phase run with no strategic override. - Idle, - /// Attempt to found a new city this turn (burns expansion points). - FoundCity, - /// Produce a unit in the first city that can afford it. - SpawnUnit, -} - -// ── Per-player compact state ───────────────────────────────────────────────── - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PlayerSnap { - pub gold: i32, - pub city_count: u32, - pub unit_count: u32, - pub expansion_points: u32, - pub culture_total: i64, - /// Copied from `PlayerState::strategic_axes["wealth"]`. - pub wealth: u8, - /// Copied from `PlayerState::strategic_axes["expansion"]`. - pub expansion_axis: u8, - /// Copied from `PlayerState::strategic_axes["production"]`. - pub production_axis: u8, - pub scoring_weights: ScoringWeights, -} - -impl PlayerSnap { - pub fn from_player(p: &PlayerState) -> Self { - Self { - gold: p.gold, - city_count: p.cities.len() as u32, - unit_count: p.units.len() as u32, - expansion_points: p.expansion_points, - culture_total: p.culture_total, - wealth: *p.strategic_axes.get("wealth").unwrap_or(&2), - expansion_axis: *p.strategic_axes.get("expansion").unwrap_or(&2), - production_axis: *p.strategic_axes.get("production").unwrap_or(&2), - scoring_weights: p.scoring_weights.clone(), - } - } -} - -// ── Snapshot ───────────────────────────────────────────────────────────────── - -/// Lightweight game snapshot. `Clone + Send` — safe to scatter across rayon threads. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct McSnapshot { - pub turn: u32, - pub players: Vec, - pub config: LairCombatConfig, - pub victory_city_count: u8, - /// Index of the player whose turn is being decided (set by the MCTS caller). - /// Used by `TreeState::action_prior` to look up that player's `scoring_weights`. - /// Defaults to 0; `ai.rs::choose_action` sets it to `player_index`. - pub active_player: u8, -} - -impl McSnapshot { - pub fn from_game_state(state: &GameState, processor: &TurnProcessor) -> Self { - Self { - turn: state.turn, - players: state.players.iter().map(PlayerSnap::from_player).collect(), - config: processor.lair_combat_config.clone(), - victory_city_count: processor.victory_city_count, - active_player: 0, - } - } - - /// Advance one turn deterministically. Economy + production + founding + - /// unit-spawn phases only (no spatial movement or combat — grid is absent). - /// - /// Byte-identical to TurnProcessor::step for the fields McSnapshot tracks. - pub fn step(&self, action: &McAction) -> McSnapshot { - let mut next = self.clone(); - next.turn += 1; - - for p in &mut next.players { - // Phase 1: economy - let gold_in = p.wealth as i32 - * p.city_count as i32 - * self.config.gold_per_wealth_per_city; - p.gold = p.gold.saturating_add(gold_in); - - // Phase 1b: culture (uses expansion_axis as culture proxy, matching processor) - let culture_per_turn = p.expansion_axis as i64 * p.city_count as i64 * 25; - p.culture_total += culture_per_turn; - - // Phase 3: expansion points - p.expansion_points = p - .expansion_points - .saturating_add(self.config.expansion_per_axis_per_turn * p.expansion_axis as u32); - - // Phase 3: city founding (only when action matches and player has enough points) - let max_cities = self.config.max_cities_per_player_base as u32 - + 3 * p.expansion_axis as u32; - if p.city_count < max_cities - && p.expansion_points >= self.config.city_founding_cost - && *action == McAction::FoundCity - { - p.city_count += 1; - p.expansion_points -= self.config.city_founding_cost; - } - - // Phase 4: unit production - let prod_budget = - self.config.prod_per_axis_per_city * p.production_axis as u32 * p.city_count; - if prod_budget >= self.config.unit_spawn_cost && *action == McAction::SpawnUnit { - p.unit_count += 1; - } - } - - next - } - - /// Heuristic leaf-value for player `pi` in [0, 1]. - /// - /// Mirrors the `ScoringWeights` MCTS leaf evaluator fields. Normalized - /// against a soft maximum so values stay in a comparable range. - pub fn heuristic_value(&self, pi: usize) -> f32 { - let p = &self.players[pi]; - let w = &p.scoring_weights; - let raw = w.city_expansion * p.city_count as f32 - + w.yield_gold * p.gold.max(0) as f32 * 0.01 - + w.yield_culture * p.culture_total as f32 * 0.0001 - + w.pop_value * p.unit_count as f32; - // Soft-normalize: sigmoid-like squash so value stays in (0, 1). - raw / (raw + 50.0) - } - - /// True if any player has met the city-count victory condition. - pub fn is_terminal(&self) -> bool { - self.players - .iter() - .any(|p| p.city_count >= self.victory_city_count as u32) - } - - /// Winner index if terminal, else None. - pub fn winner(&self) -> Option { - self.players - .iter() - .enumerate() - .find(|(_, p)| p.city_count >= self.victory_city_count as u32) - .map(|(i, _)| i) - } - - /// All legal actions from this snapshot. Empty only if terminal. - pub fn legal_actions(&self) -> Vec { - if self.is_terminal() { - return Vec::new(); - } - vec![McAction::Idle, McAction::FoundCity, McAction::SpawnUnit] - } -} - -// ── TreeState impl ─────────────────────────────────────────────────────────── - -impl TreeState for McSnapshot { - type Action = McAction; - - fn legal_actions(&self) -> Vec { - self.legal_actions() - } - - fn apply(&self, action: &McAction) -> McSnapshot { - self.step(action) - } - - fn is_terminal(&self) -> bool { - self.is_terminal() - } - - /// Personality-weighted action prior for PUCT selection (p0-38). - /// - /// Maps `McAction` variants to the `active_player`'s `ScoringWeights`: - /// - `SpawnUnit` → `military_base` (aggressive clans prefer early armies) - /// - `FoundCity` → `expansion` (expansionist clans prefer new cities) - /// - `Idle` → 1.0 baseline (everyone considers doing nothing) - /// - /// Returns unnormalised weights — `Tree::best_puct_child` normalises via - /// the PUCT formula rather than requiring pre-normalised priors. - fn action_prior(&self, action: &McAction) -> f32 { - let Some(p) = self.players.get(self.active_player as usize) else { - return 1.0; - }; - let w = &p.scoring_weights; - match action { - McAction::SpawnUnit => w.military_base.max(0.1), - McAction::FoundCity => w.expansion_base.max(0.1), - McAction::Idle => 1.0, - } - } -} - -// ── Tests ──────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - use crate::game_state::{CityEcology, MapUnit, PlayerState}; - use mc_ai::evaluator::ScoringWeights; - use mc_city::CityState; - use std::collections::BTreeMap; - - fn make_player(gold: i32, cities: usize, units: usize) -> PlayerState { - let mut axes = BTreeMap::new(); - axes.insert("wealth".into(), 3u8); - axes.insert("expansion".into(), 2u8); - axes.insert("production".into(), 2u8); - axes.insert("culture".into(), 2u8); - PlayerState { - player_index: 0, - gold, - cities: (0..cities).map(|_| CityState::default()).collect(), - unit_upkeep: vec![0; units], - strategic_axes: axes, - scoring_weights: ScoringWeights::default(), - expansion_points: 0, - city_buildings: vec![vec![]; cities], - city_improvements: Default::default(), - city_ecology: vec![CityEcology::default(); cities], - tech_state: None, - science_pool: 0, - player_tech: None, - science_yield: 0, - units: (0..units) - .map(|_| MapUnit { - col: 0, - row: 0, - hp: 10, - max_hp: 10, - attack: 5, - defense: 5, - is_fortified: false, - unit_id: "dwarf_warrior".into(), - held_resources: Vec::new(), - patrol_order: None, - ..Default::default() - }) - .collect(), - city_positions: vec![(0, 0); cities], - capital_position: Some((0, 0)), - culture_total: 0, - culture_pool: mc_culture::CulturePool::default(), - arcane_lore_pop_deducted: false, - traded_luxuries: Default::default(), - relations: Default::default(), - strategic_ledger: Default::default(), - wonders_built: Default::default(), - explored_deposits: Default::default(), - ..Default::default() - } - } - - fn make_state(cities: usize) -> GameState { - GameState { - turn: 0, - players: vec![make_player(100, cities, 2), make_player(80, cities, 1)], - grid: None, - pending_pvp_attacks: Default::default(), - ..Default::default() - } - } - - /// Economy formula: gold_in = wealth * city_count * gold_per_city. - /// Verify McSnapshot::step matches TurnProcessor::step for 10 varying states. - #[test] - fn step_economy_matches_turn_processor_for_10_seeds() { - let processor = TurnProcessor::new(300); - - for seed in 0u32..10 { - let mut state = make_state((seed % 3 + 1) as usize); - state.players[0].gold = seed as i32 * 50; - state.players[1].gold = seed as i32 * 30; - - let snap = McSnapshot::from_game_state(&state, &processor); - let next_snap = snap.step(&McAction::Idle); - - // Run full processor step (no grid, so no movement/combat) - processor.step(&mut state); - - for (pi, p) in state.players.iter().enumerate() { - assert_eq!( - next_snap.players[pi].gold, - p.gold, - "seed={seed} pi={pi}: snapshot gold {} != processor gold {}", - next_snap.players[pi].gold, - p.gold - ); - assert_eq!( - next_snap.players[pi].city_count, - p.cities.len() as u32, - "seed={seed} pi={pi}: city count mismatch" - ); - } - } - } - - #[test] - fn step_idle_does_not_found_city() { - let processor = TurnProcessor::new(300); - let state = make_state(2); - let mut snap = McSnapshot::from_game_state(&state, &processor); - snap.players[0].expansion_points = 100; - let next = snap.step(&McAction::Idle); - assert_eq!(next.players[0].city_count, 2, "Idle must not found a city"); - } - - #[test] - fn step_found_city_spends_expansion_points() { - let processor = TurnProcessor::new(300); - let state = make_state(1); - let mut snap = McSnapshot::from_game_state(&state, &processor); - snap.players[0].expansion_points = 50; - let next = snap.step(&McAction::FoundCity); - assert_eq!(next.players[0].city_count, 2, "FoundCity must add a city"); - let cost = snap.config.city_founding_cost; - let earned = - snap.config.expansion_per_axis_per_turn * snap.players[0].expansion_axis as u32; - assert_eq!( - next.players[0].expansion_points, - 50 - cost + earned, - "expansion_points after founding" - ); - } - - #[test] - fn heuristic_value_returns_value_in_unit_interval() { - let processor = TurnProcessor::new(300); - let state = make_state(3); - let snap = McSnapshot::from_game_state(&state, &processor); - for pi in 0..snap.players.len() { - let v = snap.heuristic_value(pi); - assert!(v >= 0.0 && v < 1.0, "heuristic_value out of range: {v}"); - } - } - - #[test] - fn is_terminal_triggers_at_victory_threshold() { - let processor = TurnProcessor::new(300); - let state = make_state(1); - let mut snap = McSnapshot::from_game_state(&state, &processor); - snap.players[0].city_count = snap.victory_city_count as u32; - assert!(snap.is_terminal()); - assert_eq!(snap.winner(), Some(0)); - } - - #[test] - fn legal_actions_empty_when_terminal() { - let processor = TurnProcessor::new(300); - let state = make_state(1); - let mut snap = McSnapshot::from_game_state(&state, &processor); - snap.players[0].city_count = snap.victory_city_count as u32; - assert!(snap.legal_actions().is_empty()); - } -}