feat(mc-ai): Introduce strategic axis evaluation in game_state and enhance MCTS tree logic for AI decision-making

Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
autocommit 2026-04-16 17:49:18 -07:00
parent a2784dc0b1
commit 163eb0634d
4 changed files with 103 additions and 14 deletions

View file

@ -2,6 +2,54 @@ use std::collections::HashMap;
use serde::{Deserialize, Serialize};
/// Identifier of a strategic axis, pinned to a fixed slot index.
///
/// The discriminant is the axis's position in the flat `[u8; 8]` form of
/// `strategic_axes`, so the values can be uploaded to the GPU without any
/// `String` heap allocation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum AxisId {
    Expansion = 0,
    Production = 1,
    Wealth = 2,
    Culture = 3,
    // Slots 4-7 reserved for future axes (magic, military, diplomacy, science).
}

impl AxisId {
    /// Total number of slots in the flat encoding, including the reserved
    /// slots that have no named variant yet.
    pub const COUNT: usize = 8;

    /// Every axis that currently has a name, listed in slot order.
    const ALL_NAMED: &'static [AxisId] = &[
        Self::Expansion,
        Self::Production,
        Self::Wealth,
        Self::Culture,
    ];

    /// Returns the lowercase string key used for this axis in the
    /// `HashMap<String, u8>` form of the strategic axes.
    pub fn as_str(self) -> &'static str {
        match self {
            AxisId::Expansion => "expansion",
            AxisId::Production => "production",
            AxisId::Wealth => "wealth",
            AxisId::Culture => "culture",
        }
    }
}
/// Encode a `HashMap<String, u8>` strategic-axes map to a fixed-size array
/// keyed by `AxisId` discriminant. Unknown keys are ignored; missing keys
/// default to 0. Slots 4-7 are always 0 (reserved for future axes).
pub fn axes_to_flat(axes: &HashMap<String, u8>) -> [u8; 8] {
let mut out = [0u8; 8];
for &id in AxisId::ALL_NAMED {
out[id as usize] = *axes.get(id.as_str()).unwrap_or(&0);
}
out
}
/// Decode a flat `[u8; 8]` back into a `HashMap<String, u8>` containing only
/// the named axes (slots 4-7 are ignored).
pub fn flat_to_axes(flat: &[u8; 8]) -> HashMap<String, u8> {
AxisId::ALL_NAMED
.iter()
.map(|&id| (id.as_str().to_string(), flat[id as usize]))
.collect()
}
/// Compressed AI-relevant player state — serializable for MCTS evaluation.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AiPlayerState {

View file

@ -12,5 +12,6 @@ pub mod mcts_tree;
pub use evaluator::ScoringWeights;
pub use game_state::{
AiCityState, AiPlayerState, AiProductionCandidate, AiTechCandidate, StrategicWeights,
axes_to_flat, flat_to_axes, AiCityState, AiPlayerState, AiProductionCandidate,
AiTechCandidate, AxisId, StrategicWeights,
};

View file

@ -129,7 +129,8 @@ impl<S: TreeState> Tree<S> {
}
}
/// Run one full MCTS iteration (select → expand → simulate → backpropagate).
/// Run one full MCTS iteration with the stub rollout (select → expand → simulate → backpropagate).
/// For real-game rollouts use `simulate_parallel` with a custom `rollout_fn`.
pub fn iterate(&mut self, rng: &mut XorShift64) {
let leaf = self.select(0);
let target = self.expand(leaf).unwrap_or(leaf);
@ -143,25 +144,27 @@ impl<S: TreeState> Tree<S> {
/// Each rollout derives its RNG seed from `base_seed + rollout_idx` so output is
/// reproducible regardless of thread scheduling. The tree is expanded once before
/// dispatch; all rollouts evaluate the same target node.
pub fn simulate_parallel(&mut self, n_rollouts: usize, base_seed: u64) {
///
/// `rollout_fn(state, rng) -> reward` is called per thread. Pass
/// `Tree::default_rollout` for the stub (0.5), or a real state-walk closure
/// built from `McSnapshot::step` for live game evaluation.
pub fn simulate_parallel<F>(&mut self, n_rollouts: usize, base_seed: u64, rollout_fn: F)
where
S: Sync,
F: Fn(&S, &mut XorShift64) -> f32 + Sync,
{
if n_rollouts == 0 {
return;
}
let leaf = self.select(0);
let target = self.expand(leaf).unwrap_or(leaf);
// rollout_fn mirrors Tree::simulate but is free of &self so rayon can call it.
// When mc-turn lands, replace this closure body with a real state-walk loop.
let rollout_fn = |idx: usize, rng: &mut XorShift64| -> f32 {
let _ = (idx, rng); // future rollout args — consumed when simulation is real
0.5
};
let state = &self.nodes[target].state;
let mut rewards: Vec<(usize, f32)> = (0..n_rollouts)
.into_par_iter()
.map(|i| {
let mut rng = XorShift64::new(base_seed + i as u64);
let reward = rollout_fn(target, &mut rng);
let reward = rollout_fn(state, &mut rng);
(i, reward)
})
.collect();
@ -172,4 +175,41 @@ impl<S: TreeState> Tree<S> {
self.backpropagate(target, reward);
}
}
/// Placeholder rollout that scores every state as 0.5.
///
/// Both arguments are ignored. Pass this as the `rollout_fn` argument until a
/// real simulation is wired in.
pub fn default_rollout(_state: &S, _rng: &mut XorShift64) -> f32 {
    const STUB_REWARD: f32 = 0.5;
    STUB_REWARD
}
}
// ── Rollout helpers for McSnapshot ──────────────────────────────────────────
/// Walk a copy of `state` forward `depth` steps, then score the result.
///
/// Starting from a clone of `state`, `step_fn` is applied `depth` times; each
/// call receives the current state, the 0-based step index, and the shared
/// `rng`, and returns the next state. `score_fn` is then evaluated on the
/// final state and its value is returned. `state` itself is never modified,
/// and with `depth == 0` the score of the unmodified clone is returned.
///
/// This function does not choose actions or a player itself: any action
/// selection (e.g. picking a random legal action with `rng`) belongs in
/// `step_fn`, and any per-player perspective belongs in `score_fn`.
///
/// Called from `api-gdext` and `McTreeController`; lives in this module so
/// callers import from a single crate path.
///
/// The only requirement on `S` is `Clone`; no `TreeState` impl is needed.
/// Callers pass an `&McSnapshot` directly and do the state-walk with
/// `McSnapshot::step` inside `step_fn` (mc-turn depends on mc-ai, not the
/// reverse, so implementing `TreeState` here would create a circular dep).
///
/// Both callbacks are `?Sized`, so `&dyn Fn(..)` trait objects are accepted
/// as well as references to concrete closures.
pub fn rollout_snapshot<S, FStep, FScore>(
    state: &S,
    rng: &mut XorShift64,
    depth: u32,
    step_fn: &FStep,
    score_fn: &FScore,
) -> f32
where
    S: Clone,
    FStep: ?Sized + Fn(&S, u32, &mut XorShift64) -> S,
    FScore: ?Sized + Fn(&S) -> f32,
{
    // Each step consumes the previous state and yields the next one; the rng
    // is reborrowed per step so every call advances the same generator.
    let final_state = (0..depth).fold(state.clone(), |current, ply| {
        step_fn(&current, ply, &mut *rng)
    });
    score_fn(&final_state)
}

View file

@ -80,7 +80,7 @@ fn terminal_state_has_no_legal_actions() {
#[test]
fn simulate_parallel_visits_match_n_rollouts() {
let mut tree = Tree::new(ToyState { depth: 3, branching: 4 });
tree.simulate_parallel(200, 42);
tree.simulate_parallel(200, 42, Tree::<ToyState>::default_rollout);
// All 200 rollouts must be backpropagated: root visits == 200.
assert_eq!(tree.root().visits, 200);
// Stubbed rollout returns 0.5 → wins == 0.5 * visits.
@ -91,7 +91,7 @@ fn simulate_parallel_visits_match_n_rollouts() {
fn simulate_parallel_deterministic_across_runs() {
let make = || {
let mut tree = Tree::new(ToyState { depth: 3, branching: 4 });
tree.simulate_parallel(500, 99);
tree.simulate_parallel(500, 99, Tree::<ToyState>::default_rollout);
(tree.root().visits, tree.nodes.len())
};
let (v1, n1) = make();
@ -118,7 +118,7 @@ fn parallel_faster_than_serial_for_large_n() {
let parallel_ms = {
let mut tree = Tree::new(ToyState { depth: 3, branching: 4 });
let t = Instant::now();
tree.simulate_parallel(N, 1);
tree.simulate_parallel(N, 1, Tree::<ToyState>::default_rollout);
t.elapsed().as_millis()
};