diff --git a/src/simulator/crates/mc-ai/src/game_state.rs b/src/simulator/crates/mc-ai/src/game_state.rs index 98b3b360..a979912e 100644 --- a/src/simulator/crates/mc-ai/src/game_state.rs +++ b/src/simulator/crates/mc-ai/src/game_state.rs @@ -2,6 +2,54 @@ use std::collections::HashMap; use serde::{Deserialize, Serialize}; +/// Fixed-index strategic axis identifiers. Used to key the flat `[u8; 8]` form +/// of `strategic_axes` for GPU upload without String heap allocation. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum AxisId { + Expansion = 0, + Production = 1, + Wealth = 2, + Culture = 3, + // Slots 4-7 reserved for future axes (magic, military, diplomacy, science). +} + +impl AxisId { + pub const COUNT: usize = 8; + + pub fn as_str(self) -> &'static str { + match self { + Self::Expansion => "expansion", + Self::Production => "production", + Self::Wealth => "wealth", + Self::Culture => "culture", + } + } + + const ALL_NAMED: &'static [AxisId] = + &[AxisId::Expansion, AxisId::Production, AxisId::Wealth, AxisId::Culture]; +} + +/// Encode a `HashMap<String, u8>` strategic-axes map to a fixed-size array +/// keyed by `AxisId` discriminant. Unknown keys are ignored; missing keys +/// default to 0. Slots 4-7 are always 0 (reserved for future axes). +pub fn axes_to_flat(axes: &HashMap) -> [u8; 8] { + let mut out = [0u8; 8]; + for &id in AxisId::ALL_NAMED { + out[id as usize] = *axes.get(id.as_str()).unwrap_or(&0); + } + out +} + +/// Decode a flat `[u8; 8]` back into a `HashMap<String, u8>` containing only +/// the named axes (slots 4-7 are ignored). +pub fn flat_to_axes(flat: &[u8; 8]) -> HashMap { + AxisId::ALL_NAMED + .iter() + .map(|&id| (id.as_str().to_string(), flat[id as usize])) + .collect() +} + /// Compressed AI-relevant player state — serializable for MCTS evaluation. 
#[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct AiPlayerState { diff --git a/src/simulator/crates/mc-ai/src/lib.rs b/src/simulator/crates/mc-ai/src/lib.rs index f4728d69..5b7ae285 100644 --- a/src/simulator/crates/mc-ai/src/lib.rs +++ b/src/simulator/crates/mc-ai/src/lib.rs @@ -12,5 +12,6 @@ pub mod mcts_tree; pub use evaluator::ScoringWeights; pub use game_state::{ - AiCityState, AiPlayerState, AiProductionCandidate, AiTechCandidate, StrategicWeights, + axes_to_flat, flat_to_axes, AiCityState, AiPlayerState, AiProductionCandidate, + AiTechCandidate, AxisId, StrategicWeights, }; diff --git a/src/simulator/crates/mc-ai/src/mcts_tree.rs b/src/simulator/crates/mc-ai/src/mcts_tree.rs index a13a8224..1609d6b5 100644 --- a/src/simulator/crates/mc-ai/src/mcts_tree.rs +++ b/src/simulator/crates/mc-ai/src/mcts_tree.rs @@ -129,7 +129,8 @@ impl Tree { } } - /// Run one full MCTS iteration (select → expand → simulate → backpropagate). + /// Run one full MCTS iteration with the stub rollout (select → expand → simulate → backpropagate). + /// For real-game rollouts use `simulate_parallel` with a custom `rollout_fn`. pub fn iterate(&mut self, rng: &mut XorShift64) { let leaf = self.select(0); let target = self.expand(leaf).unwrap_or(leaf); @@ -143,25 +144,27 @@ impl Tree { /// Each rollout derives its RNG seed from `base_seed + rollout_idx` so output is /// reproducible regardless of thread scheduling. The tree is expanded once before /// dispatch; all rollouts evaluate the same target node. - pub fn simulate_parallel(&mut self, n_rollouts: usize, base_seed: u64) { + /// + /// `rollout_fn(state, rng) -> reward` is called once per rollout, in parallel. Pass + /// `Tree::default_rollout` for the stub (0.5), or a real state-walk closure + /// built from `McSnapshot::step` for live game evaluation. 
+ pub fn simulate_parallel(&mut self, n_rollouts: usize, base_seed: u64, rollout_fn: F) + where + S: Sync, + F: Fn(&S, &mut XorShift64) -> f32 + Sync, + { if n_rollouts == 0 { return; } let leaf = self.select(0); let target = self.expand(leaf).unwrap_or(leaf); - - // rollout_fn mirrors Tree::simulate but is free of &self so rayon can call it. - // When mc-turn lands, replace this closure body with a real state-walk loop. - let rollout_fn = |idx: usize, rng: &mut XorShift64| -> f32 { - let _ = (idx, rng); // future rollout args — consumed when simulation is real - 0.5 - }; + let state = &self.nodes[target].state; let mut rewards: Vec<(usize, f32)> = (0..n_rollouts) .into_par_iter() .map(|i| { let mut rng = XorShift64::new(base_seed + i as u64); - let reward = rollout_fn(target, &mut rng); + let reward = rollout_fn(state, &mut rng); (i, reward) }) .collect(); @@ -172,4 +175,41 @@ impl Tree { self.backpropagate(target, reward); } } + + /// Stub rollout — returns 0.5 regardless of state. + /// Use as `rollout_fn` argument before real simulation is wired. + pub fn default_rollout(_state: &S, _rng: &mut XorShift64) -> f32 { + 0.5 + } +} + +// ── Rollout helpers for McSnapshot ────────────────────────────────────────── + +/// Walk a clone of `state` forward `depth` steps via `step_fn` (e.g. random legal actions), +/// then return `score_fn` of the final state (e.g. the heuristic value for one player). +/// +/// Called from `api-gdext` and `McTreeController`; lives in this module so +/// callers import from a single crate path. +/// +/// The function is generic over any `S: Clone` rather than `TreeState`: callers pass an +/// `&McSnapshot` directly (mc-turn depends on mc-ai, not the reverse). +/// The `TreeState` trait is NOT implemented here to avoid the circular dep — +/// callers supply a `step_fn` built on `McSnapshot::step` instead. 
+pub fn rollout_snapshot( + state: &S, + rng: &mut XorShift64, + depth: u32, + step_fn: &FStep, + score_fn: &FScore, +) -> f32 +where + S: Clone, + FStep: Fn(&S, u32, &mut XorShift64) -> S, + FScore: Fn(&S) -> f32, +{ + let mut s = state.clone(); + for d in 0..depth { + s = step_fn(&s, d, rng); + } + score_fn(&s) } diff --git a/src/simulator/crates/mc-ai/tests/mcts_basic.rs b/src/simulator/crates/mc-ai/tests/mcts_basic.rs index f47bd94c..44b34a0d 100644 --- a/src/simulator/crates/mc-ai/tests/mcts_basic.rs +++ b/src/simulator/crates/mc-ai/tests/mcts_basic.rs @@ -80,7 +80,7 @@ fn terminal_state_has_no_legal_actions() { #[test] fn simulate_parallel_visits_match_n_rollouts() { let mut tree = Tree::new(ToyState { depth: 3, branching: 4 }); - tree.simulate_parallel(200, 42); + tree.simulate_parallel(200, 42, Tree::::default_rollout); // All 200 rollouts must be backpropagated: root visits == 200. assert_eq!(tree.root().visits, 200); // Stubbed rollout returns 0.5 → wins == 0.5 * visits. @@ -91,7 +91,7 @@ fn simulate_parallel_visits_match_n_rollouts() { fn simulate_parallel_deterministic_across_runs() { let make = || { let mut tree = Tree::new(ToyState { depth: 3, branching: 4 }); - tree.simulate_parallel(500, 99); + tree.simulate_parallel(500, 99, Tree::::default_rollout); (tree.root().visits, tree.nodes.len()) }; let (v1, n1) = make(); @@ -118,7 +118,7 @@ fn parallel_faster_than_serial_for_large_n() { let parallel_ms = { let mut tree = Tree::new(ToyState { depth: 3, branching: 4 }); let t = Instant::now(); - tree.simulate_parallel(N, 1); + tree.simulate_parallel(N, 1, Tree::::default_rollout); t.elapsed().as_millis() };