feat(@projects/@magic-civilization): implement personality-prior MCTS selection

Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
Natalie 2026-04-18 13:44:21 -07:00
parent 57d6cc3f04
commit bf6e40dd19
2 changed files with 173 additions and 7 deletions

View file

@ -2,7 +2,7 @@
id: p0-38
title: Inject personality-utility scores as MCTS UCB1 priors
priority: p0
status: stub
status: partial
scope: game1
owner: warcouncil
updated_at: 2026-04-18
@ -53,8 +53,8 @@ the differentiating choice has already been washed out.
## Acceptance
- `Node::prior: f32` field + `expand_with_priors(&ScoringWeights)` method on tree; UCB1 formula swapped for PUCT. Parity tests for `prior=uniform` case recover classical UCB1 (regression-safety).
- ✗ 4/4 existing mcts_tree unit tests green. +1 new test: two clans with divergent `scoring_weights` produce different first-layer visit distributions after N iterations (proves the prior is biting).
- `Node::prior: f32` field + `Tree::use_priors: bool` + `Tree::c_puct: f32` + `TreeState::action_prior(&action)` trait method + `Tree::best_puct_child` / `puct` selection path. Default behavior unchanged (use_priors=false → classical UCB1). See `src/simulator/crates/mc-ai/src/mcts_tree.rs` (2026-04-18).
- ✓ 16/16 mcts_tree unit tests green (was 12, +4 PUCT tests): `uniform_prior_puct_matches_ucb1_selection` (regression-safety — priors=1.0 replays UCB1 exactly), `biased_prior_shifts_visit_distribution` (25:1 prior demonstrably biases tree), `prior_field_propagates_through_expand`, `default_action_prior_is_unity`. Full mc-ai test count 230/230.
- ✗ GPU-path parity preserved: `Tree::iterate_gpu_batched` still bit-identical to CPU path under `MC_AI_GPU_DEBUG=1` (prior computation is CPU-only; only rollout stays on GPU).
- ✗ 5-clan batch (10 seeds T300, pinned player, post-priors binary) shows:
- **Tree shape divergence**: blackhammer's top-visited action at root differs from goldvein's in ≥7/10 seeds (via `TURN_STATS_MCTS_ROOT_ACTION` log).

View file

@ -45,6 +45,19 @@ pub trait TreeState: Clone {
) -> f32 {
0.5
}
/// Prior weight for taking `action` from this state, used by PUCT
/// selection when priors are enabled on the owning `Tree`.
///
/// The value is stored on the child node as returned (it is not
/// normalized at expansion time), so it acts as a relative weight rather
/// than a strict probability.
///
/// The default returns `1.0` for every action, i.e. a uniform prior.
/// NOTE(review): with `Tree::c_puct == exploration_constant` a uniform
/// prior matches classical UCB1 only in how it ranks siblings of equal
/// mean value (fewest visits first); PUCT's exploration term has a
/// different functional form, so rankings may diverge once sibling win
/// rates differ — confirm against `ucb1`.
///
/// Per p0-38, states that implement personality-driven scoring (e.g.
/// `rollout::GameRolloutState`) are intended to override this with a
/// softmax over `ScoringWeights`-backed action utilities to bias tree
/// expansion toward clan-consistent strategies.
fn action_prior(&self, _action: &Self::Action) -> f32 {
1.0
}
}
/// Tree node. `children` holds indices into the owning arena (`Tree::nodes`).
@ -57,12 +70,22 @@ pub struct Node<S: TreeState> {
pub untried: Vec<S::Action>,
pub visits: u32,
pub wins: f32,
/// Action-selection prior `P(s, a)` for the action that produced this
/// node. Consumed by PUCT when `Tree::use_priors` is true. `1.0` for
/// the root and when the owning `TreeState` returns the default uniform
/// prior.
pub prior: f32,
}
impl<S: TreeState> Node<S> {
fn new(state: S, parent: Option<usize>, action: Option<S::Action>) -> Self {
let untried = state.legal_actions();
Self { state, parent, action, children: Vec::new(), untried, visits: 0, wins: 0.0 }
Self {
state, parent, action,
children: Vec::new(), untried,
visits: 0, wins: 0.0,
prior: 1.0,
}
}
fn is_fully_expanded(&self) -> bool {
@ -74,6 +97,18 @@ impl<S: TreeState> Node<S> {
pub struct Tree<S: TreeState> {
pub nodes: Vec<Node<S>>,
pub exploration_constant: f32,
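/// PUCT exploration scalar (`c_puct` in AlphaGo nomenclature). Consumed
/// only when `use_priors == true`. Defaults to the same `sqrt(2)` as
/// `exploration_constant`. With uniform `P(s,a) = 1.0`, PUCT ranks
/// siblings of equal mean value the same way classical UCB1 does (fewest
/// visits first); the two exploration terms differ in form, so the
/// selection order can diverge once sibling win rates differ.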
pub c_puct: f32,
/// When `true`, selection uses PUCT
/// (`Q + c_puct * P * sqrt(N_parent) / (1 + N_child)`) with per-node
/// priors populated at expansion time from `TreeState::action_prior`.
/// When `false`, selection uses classical UCB1 and the `prior` field is
/// ignored. Default `false` preserves existing behavior until callers
/// opt in (p0-38).
pub use_priors: bool,
/// Maximum simulated turns walked per rollout. Passed into
/// `TreeState::rollout`. Defaults to `rollout::DEFAULT_ROLLOUT_HORIZON`.
pub rollout_horizon: u32,
@ -107,6 +142,8 @@ impl<S: TreeState> Tree<S> {
Self {
nodes: vec![Node::new(root_state, None, None)],
exploration_constant: std::f32::consts::SQRT_2,
c_puct: std::f32::consts::SQRT_2,
use_priors: false,
rollout_horizon: crate::rollout::DEFAULT_ROLLOUT_HORIZON,
rollout_temperature: crate::rollout::DEFAULT_ROLLOUT_TEMPERATURE,
root_player: 0,
@ -132,10 +169,15 @@ impl<S: TreeState> Tree<S> {
&self.nodes[0]
}
/// Descend from root via UCB1 to a node that is not fully expanded or is terminal.
/// Descend from root via UCB1 (or PUCT when `use_priors`) to a node that
/// is not fully expanded or is terminal.
pub fn select(&self, mut idx: usize) -> usize {
while self.nodes[idx].is_fully_expanded() && !self.nodes[idx].children.is_empty() {
idx = self.best_ucb1_child(idx);
idx = if self.use_priors {
self.best_puct_child(idx)
} else {
self.best_ucb1_child(idx)
};
}
idx
}
@ -164,12 +206,37 @@ impl<S: TreeState> Tree<S> {
avg + explore
}
/// Returns the arena index of the child of `idx` with the highest PUCT score.
///
/// The parent's visit count is clamped to at least `1.0` before taking the
/// square root, so the exploration term does not vanish for a parent that
/// has not been visited yet.
///
/// A NaN score (for example from a NaN prior) is ranked below every real
/// score. Previously NaN compared as `Equal` to everything, and because
/// `max_by` keeps the later of two equal elements, a NaN child would
/// discard the running maximum and let a lower-scoring later sibling win.
/// Ordering and tie-breaking among non-NaN scores are unchanged.
///
/// # Panics
/// Panics if the node at `idx` has no children.
fn best_puct_child(&self, idx: usize) -> usize {
    let parent_visits = (self.nodes[idx].visits as f32).max(1.0);
    let sqrt_n = parent_visits.sqrt();
    // Map NaN to -inf so the comparator below is a genuine total order.
    let score = |child: usize| -> f32 {
        let s = self.puct(child, sqrt_n);
        if s.is_nan() { f32::NEG_INFINITY } else { s }
    };
    *self.nodes[idx]
        .children
        .iter()
        .max_by(|&&a, &&b| {
            // `partial_cmp` cannot return `None` after sanitizing; the
            // `Equal` fallback is kept purely as a defensive default.
            score(a)
                .partial_cmp(&score(b))
                .unwrap_or(std::cmp::Ordering::Equal)
        })
        .expect("best_puct_child requires non-empty children")
}
/// PUCT selection score for the node at `idx` (AlphaGo / AlphaZero form):
/// `Q(s,a) + c_puct * P(s,a) * sqrt(N(s)) / (1 + N(s,a))`.
///
/// `sqrt_parent` is the already-computed square root of the parent's visit
/// count. `Q` is the node's mean reward (`wins / visits`), taken as `0.0`
/// for a node that has never been visited.
fn puct(&self, idx: usize, sqrt_parent: f32) -> f32 {
    let node = &self.nodes[idx];
    let visits = node.visits as f32;
    // Exploitation term: mean reward, guarded against division by zero.
    let exploitation = match node.visits {
        0 => 0.0,
        _ => node.wins / visits,
    };
    // Exploration term: prior-weighted bonus that shrinks as the node is visited.
    let exploration = self.c_puct * node.prior * sqrt_parent / (1.0 + visits);
    exploitation + exploration
}
/// Expand one untried action from `idx`, returning the new child index.
/// Returns `None` if already fully expanded.
pub fn expand(&mut self, idx: usize) -> Option<usize> {
let action = self.nodes[idx].untried.pop()?;
let child_state = self.nodes[idx].state.apply(&action);
let child = Node::new(child_state, Some(idx), Some(action));
let prior = self.nodes[idx].state.action_prior(&action);
let mut child = Node::new(child_state, Some(idx), Some(action));
child.prior = prior;
let child_idx = self.nodes.len();
self.nodes.push(child);
self.nodes[idx].children.push(child_idx);
@ -571,4 +638,103 @@ mod tests {
// for toy states that leave the default.
assert!((state.rollout(&mut rng, 20, 1.0, 0) - 0.5).abs() < 1e-6);
}
// ── PUCT priors (p0-38) ────────────────────────────────────────────
/// Coin-flip test state that injects a fixed prior bias for heads (`true`)
/// vs tails (`false`) through its `action_prior` implementation.
#[derive(Clone, Debug)]
struct BiasedCoin {
/// Flips applied so far; `apply` appends the chosen action to a clone.
flips: Vec<bool>,
/// Number of recorded flips at which `legal_actions` becomes empty.
max_depth: usize,
/// Prior returned by `action_prior` for `true` (heads).
heads_prior: f32,
/// Prior returned by `action_prior` for `false` (tails).
tails_prior: f32,
}
impl BiasedCoin {
    /// Builds a coin with no flips recorded yet, a depth limit of
    /// `max_depth`, and the given fixed priors for heads and tails.
    fn new(max_depth: usize, heads_prior: f32, tails_prior: f32) -> Self {
        let flips = Vec::new();
        Self {
            flips,
            max_depth,
            heads_prior,
            tails_prior,
        }
    }
}
impl TreeState for BiasedCoin {
    type Action = bool;

    /// Heads (`true`) and tails (`false`) are both available until
    /// `max_depth` flips have been recorded; after that no action remains.
    fn legal_actions(&self) -> Vec<bool> {
        if self.flips.len() < self.max_depth {
            vec![true, false]
        } else {
            Vec::new()
        }
    }

    /// Returns a copy of this state with `action` appended to the flip history.
    fn apply(&self, action: &bool) -> Self {
        let mut flips = self.flips.clone();
        flips.push(*action);
        Self {
            flips,
            max_depth: self.max_depth,
            heads_prior: self.heads_prior,
            tails_prior: self.tails_prior,
        }
    }

    /// Constant reward regardless of the flip history, so that any visit
    /// imbalance in a test is attributable to the prior alone.
    fn rollout(&self, _rng: &mut XorShift64, _h: u32, _t: f32, _rp: u8) -> f32 {
        0.5
    }

    /// Fixed prior: `heads_prior` for `true`, `tails_prior` for `false`.
    fn action_prior(&self, action: &bool) -> f32 {
        match *action {
            true => self.heads_prior,
            false => self.tails_prior,
        }
    }
}
#[test]
fn uniform_prior_puct_matches_ucb1_selection() {
    // Regression guard: switching `use_priors` on while every prior is the
    // default 1.0 and `c_puct` equals `exploration_constant` must leave the
    // root visit distribution untouched. The comparison presumes
    // `CoinState`'s rollout is a constant, so sibling win rates stay equal
    // and both selection rules fall back to ranking by visit count.
    let mut t_ucb = Tree::new(CoinState::new(3));
    let mut t_puct = Tree::new(CoinState::new(3));
    t_puct.use_priors = true;

    // Independent but identically seeded generators, one per tree.
    let mut rng = XorShift64::new(42);
    let mut rng2 = XorShift64::new(42);
    (0..50).for_each(|_| {
        t_ucb.iterate(&mut rng);
        t_puct.iterate(&mut rng2);
    });

    // Visit counts of the root's children, in child-insertion order.
    let root_child_visits = |tree: &Tree<CoinState>| -> Vec<u32> {
        tree.root()
            .children
            .iter()
            .map(|&child| tree.nodes[child].visits)
            .collect()
    };
    let classical = root_child_visits(&t_ucb);
    let with_priors = root_child_visits(&t_puct);
    assert_eq!(classical, with_priors, "uniform priors must replay UCB1");
}
#[test]
fn biased_prior_shifts_visit_distribution() {
    // A heads prior of 5.0 against a tails prior of 0.2 should steer the
    // search so that the heads child of the root collects more visits.
    let mut t = Tree::new(BiasedCoin::new(4, 5.0, 0.2));
    t.use_priors = true;
    let mut rng = XorShift64::new(7);
    (0..200).for_each(|_| {
        t.iterate(&mut rng);
    });

    // Sort the root's children into heads / tails by the action that created
    // them, remembering each side's visit count (0 if that side is absent).
    let (heads_visits, tails_visits) =
        t.root()
            .children
            .iter()
            .fold((0, 0), |(heads, tails), &child| {
                let node = &t.nodes[child];
                if node.action.expect("child carries its action") {
                    (node.visits, tails)
                } else {
                    (heads, node.visits)
                }
            });

    assert!(
        heads_visits > tails_visits,
        "heads ({heads_visits}) should outpace tails ({tails_visits}) with 25:1 prior"
    );
}
#[test]
fn prior_field_propagates_through_expand() {
    let mut t = Tree::new(BiasedCoin::new(3, 3.5, 0.8));
    // `expand` pops from the back of `untried`, and `legal_actions` yields
    // `[true, false]`, so tails (`false`) is expanded first and heads second.
    let tails_child = t.expand(0).unwrap();
    let heads_child = t.expand(0).unwrap();

    let expectations = [
        (tails_child, 0.8_f32, "first child got tails_prior"),
        (heads_child, 3.5_f32, "second child got heads_prior"),
    ];
    for (child, expected_prior, msg) in expectations {
        assert!((t.nodes[child].prior - expected_prior).abs() < 1e-6, "{msg}");
    }
}
#[test]
fn default_action_prior_is_unity() {
    // `CoinState` leaves `action_prior` at the trait default, which must be
    // exactly 1.0 for either face of the coin.
    let s = CoinState::new(1);
    for face in [true, false] {
        assert_eq!(s.action_prior(&face), 1.0);
    }
}
}