feat(mc-ai): ✨ Introduce strategic axis evaluation in game_state and enhance MCTS tree logic for AI decision-making
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
a2784dc0b1
commit
163eb0634d
4 changed files with 103 additions and 14 deletions
|
|
@ -2,6 +2,54 @@ use std::collections::HashMap;
|
|||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Fixed-index strategic axis identifiers.
///
/// Each variant's discriminant is its slot in the flat `[u8; AxisId::COUNT]`
/// form of `strategic_axes`, which is used for GPU upload without `String`
/// heap allocation. Discriminants are part of that layout and must not be
/// renumbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum AxisId {
    Expansion = 0,
    Production = 1,
    Wealth = 2,
    Culture = 3,
    // Slots 4-7 reserved for future axes (magic, military, diplomacy, science).
}

impl AxisId {
    /// Number of slots in the flat array form.
    ///
    /// This is larger than the number of named variants: the trailing slots
    /// are reserved and stay zero until further axes are introduced.
    pub const COUNT: usize = 8;

    /// Returns the key under which this axis is stored in the
    /// `HashMap<String, u8>` form of the strategic axes.
    pub const fn as_str(self) -> &'static str {
        // Exhaustive on purpose: adding a variant must fail to compile here
        // until it is given a name.
        match self {
            Self::Expansion => "expansion",
            Self::Production => "production",
            Self::Wealth => "wealth",
            Self::Culture => "culture",
        }
    }

    /// Every axis that currently has a name, in slot order. Reserved slots
    /// have no variant and therefore never appear here.
    const ALL_NAMED: &'static [AxisId] =
        &[AxisId::Expansion, AxisId::Production, AxisId::Wealth, AxisId::Culture];
}
|
||||
|
||||
/// Encode a `HashMap<String, u8>` strategic-axes map to a fixed-size array
|
||||
/// keyed by `AxisId` discriminant. Unknown keys are ignored; missing keys
|
||||
/// default to 0. Slots 4-7 are always 0 (reserved for future axes).
|
||||
pub fn axes_to_flat(axes: &HashMap<String, u8>) -> [u8; 8] {
|
||||
let mut out = [0u8; 8];
|
||||
for &id in AxisId::ALL_NAMED {
|
||||
out[id as usize] = *axes.get(id.as_str()).unwrap_or(&0);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Decode a flat `[u8; 8]` back into a `HashMap<String, u8>` containing only
|
||||
/// the named axes (slots 4-7 are ignored).
|
||||
pub fn flat_to_axes(flat: &[u8; 8]) -> HashMap<String, u8> {
|
||||
AxisId::ALL_NAMED
|
||||
.iter()
|
||||
.map(|&id| (id.as_str().to_string(), flat[id as usize]))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Compressed AI-relevant player state — serializable for MCTS evaluation.
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct AiPlayerState {
|
||||
|
|
|
|||
|
|
@ -12,5 +12,6 @@ pub mod mcts_tree;
|
|||
|
||||
pub use evaluator::ScoringWeights;
|
||||
pub use game_state::{
|
||||
AiCityState, AiPlayerState, AiProductionCandidate, AiTechCandidate, StrategicWeights,
|
||||
axes_to_flat, flat_to_axes, AiCityState, AiPlayerState, AiProductionCandidate,
|
||||
AiTechCandidate, AxisId, StrategicWeights,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -129,7 +129,8 @@ impl<S: TreeState> Tree<S> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Run one full MCTS iteration (select → expand → simulate → backpropagate).
|
||||
/// Run one full MCTS iteration with the stub rollout (select → expand → simulate → backpropagate).
|
||||
/// For real-game rollouts use `simulate_parallel` with a custom `rollout_fn`.
|
||||
pub fn iterate(&mut self, rng: &mut XorShift64) {
|
||||
let leaf = self.select(0);
|
||||
let target = self.expand(leaf).unwrap_or(leaf);
|
||||
|
|
@ -143,25 +144,27 @@ impl<S: TreeState> Tree<S> {
|
|||
/// Each rollout derives its RNG seed from `base_seed + rollout_idx` so output is
|
||||
/// reproducible regardless of thread scheduling. The tree is expanded once before
|
||||
/// dispatch; all rollouts evaluate the same target node.
|
||||
pub fn simulate_parallel(&mut self, n_rollouts: usize, base_seed: u64) {
|
||||
///
|
||||
/// `rollout_fn(state, rng) -> reward` is called per thread. Pass
|
||||
/// `Tree::default_rollout` for the stub (0.5), or a real state-walk closure
|
||||
/// built from `McSnapshot::step` for live game evaluation.
|
||||
pub fn simulate_parallel<F>(&mut self, n_rollouts: usize, base_seed: u64, rollout_fn: F)
|
||||
where
|
||||
S: Sync,
|
||||
F: Fn(&S, &mut XorShift64) -> f32 + Sync,
|
||||
{
|
||||
if n_rollouts == 0 {
|
||||
return;
|
||||
}
|
||||
let leaf = self.select(0);
|
||||
let target = self.expand(leaf).unwrap_or(leaf);
|
||||
|
||||
// rollout_fn mirrors Tree::simulate but is free of &self so rayon can call it.
|
||||
// When mc-turn lands, replace this closure body with a real state-walk loop.
|
||||
let rollout_fn = |idx: usize, rng: &mut XorShift64| -> f32 {
|
||||
let _ = (idx, rng); // future rollout args — consumed when simulation is real
|
||||
0.5
|
||||
};
|
||||
let state = &self.nodes[target].state;
|
||||
|
||||
let mut rewards: Vec<(usize, f32)> = (0..n_rollouts)
|
||||
.into_par_iter()
|
||||
.map(|i| {
|
||||
let mut rng = XorShift64::new(base_seed + i as u64);
|
||||
let reward = rollout_fn(target, &mut rng);
|
||||
let reward = rollout_fn(state, &mut rng);
|
||||
(i, reward)
|
||||
})
|
||||
.collect();
|
||||
|
|
@ -172,4 +175,41 @@ impl<S: TreeState> Tree<S> {
|
|||
self.backpropagate(target, reward);
|
||||
}
|
||||
}
|
||||
|
||||
/// Stub rollout — returns 0.5 regardless of state.
|
||||
/// Use as `rollout_fn` argument before real simulation is wired.
|
||||
pub fn default_rollout(_state: &S, _rng: &mut XorShift64) -> f32 {
|
||||
0.5
|
||||
}
|
||||
}
|
||||
|
||||
// ── Rollout helpers for McSnapshot ──────────────────────────────────────────
|
||||
|
||||
/// Walk `snapshot` forward up to `depth` steps with random legal actions,
|
||||
/// then return the heuristic value for `player_index`.
|
||||
///
|
||||
/// Called from `api-gdext` and `McTreeController`; lives in this module so
|
||||
/// callers import from a single crate path.
|
||||
///
|
||||
/// The function is generic via the `TreeState` blanket: callers pass an
|
||||
/// `&McSnapshot` directly (mc-turn depends on mc-ai, not the reverse).
|
||||
/// The `TreeState` trait is NOT implemented here to avoid the circular dep —
|
||||
/// callers do the state-walk with `McSnapshot::step` directly.
|
||||
pub fn rollout_snapshot<S, FStep, FScore>(
|
||||
state: &S,
|
||||
rng: &mut XorShift64,
|
||||
depth: u32,
|
||||
step_fn: &FStep,
|
||||
score_fn: &FScore,
|
||||
) -> f32
|
||||
where
|
||||
S: Clone,
|
||||
FStep: Fn(&S, u32, &mut XorShift64) -> S,
|
||||
FScore: Fn(&S) -> f32,
|
||||
{
|
||||
let mut s = state.clone();
|
||||
for d in 0..depth {
|
||||
s = step_fn(&s, d, rng);
|
||||
}
|
||||
score_fn(&s)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ fn terminal_state_has_no_legal_actions() {
|
|||
#[test]
|
||||
fn simulate_parallel_visits_match_n_rollouts() {
|
||||
let mut tree = Tree::new(ToyState { depth: 3, branching: 4 });
|
||||
tree.simulate_parallel(200, 42);
|
||||
tree.simulate_parallel(200, 42, Tree::<ToyState>::default_rollout);
|
||||
// All 200 rollouts must be backpropagated: root visits == 200.
|
||||
assert_eq!(tree.root().visits, 200);
|
||||
// Stubbed rollout returns 0.5 → wins == 0.5 * visits.
|
||||
|
|
@ -91,7 +91,7 @@ fn simulate_parallel_visits_match_n_rollouts() {
|
|||
fn simulate_parallel_deterministic_across_runs() {
|
||||
let make = || {
|
||||
let mut tree = Tree::new(ToyState { depth: 3, branching: 4 });
|
||||
tree.simulate_parallel(500, 99);
|
||||
tree.simulate_parallel(500, 99, Tree::<ToyState>::default_rollout);
|
||||
(tree.root().visits, tree.nodes.len())
|
||||
};
|
||||
let (v1, n1) = make();
|
||||
|
|
@ -118,7 +118,7 @@ fn parallel_faster_than_serial_for_large_n() {
|
|||
let parallel_ms = {
|
||||
let mut tree = Tree::new(ToyState { depth: 3, branching: 4 });
|
||||
let t = Instant::now();
|
||||
tree.simulate_parallel(N, 1);
|
||||
tree.simulate_parallel(N, 1, Tree::<ToyState>::default_rollout);
|
||||
t.elapsed().as_millis()
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue