#!/usr/bin/env python3
"""determinism-compare.py — Per-seed parity check for CPU↔GPU AI batches.

Companion to `tools/determinism-audit.sh` Scenario 2. Reads two batch output
directories (each containing `game_*_seedN/turn_stats.jsonl` dirs, N being the
decimal seed), matches seeds, then:

  - Integer fields (all non-float values in turn_stats schema) MUST match
    exactly across turns.
  - Float fields MUST match within `--float-tol` (default 1e-4 absolute).

Exits:
  0  all seeds parity-green
  1  at least one seed diverged (integer mismatch, float beyond tolerance,
     or a missing/empty turn_stats.jsonl)
  2  usage / I/O error

stdlib only — no pip installs. Mirrors the stdlib-only policy of
autoplay-report.py.
"""

from __future__ import annotations

import argparse
import json
import math
import sys
from pathlib import Path
from typing import Any

# Numeric fields that must be exactly equal (no tolerance applied).
INT_FIELDS_REQUIRED_EQUAL: set[str] = {
    # Top-level of each turn_stats line
    "turn",
    "winner_index",
    # aggregate.*
    "total_combats",
    "total_cities_founded",
    "total_cities_captured",
    "turn_first_combat",
    "turn_first_city_captured",
    # per-player entries under player_stats
    "pop",
    "pop_peak",
    "mil",
    "cities",
    "cities_captured",
    "cities_lost",
    "gold",
    "gold_peak",
    "techs",
    "tiles",
    "buildings",
    "happiness",
    "food_total",
    "production_total",
    "kills",
    "units_lost",
    "turn_first_pop_3",
    "turn_first_pop_4",
}

# Strings must match exactly (no tolerance makes sense).
# NOTE: walk_compare compares *every* string leaf exactly, so this set is
# informational; it is kept because other tooling may import it.
STR_FIELDS_REQUIRED_EQUAL: set[str] = {"outcome", "victory_type", "winner_personality"}

# Fields that legitimately vary across runs (wall-clock timings, boot stamps, etc).
# Not a determinism signal — excluded from the parity check.
EXCLUDE_FIELDS: set[str] = {
    "wall_clock_sec",  # per-turn wall-clock time
    "start_stamp",  # ISO-8601 boot timestamp in meta.json
    "finished_at",  # audit completion timestamp
}

# Name of the per-game stats file compared for each seed.
_TURN_STATS_NAME = "turn_stats.jsonl"


def find_seed_dirs(root: Path) -> dict[int, Path]:
    """Return ``{seed: dir}`` for the ``game_*_seedN`` directories under *root*.

    When several directories share a seed, the lexicographically greatest
    path wins. That is the most recent run only if the directory name embeds
    a sortable timestamp — presumably it does; verify against the producer.

    Raises:
        OSError: if *root* cannot be listed.
    """
    by_seed: dict[int, list[Path]] = {}
    for entry in root.iterdir():
        if not entry.is_dir() or not entry.name.startswith("game_"):
            continue
        _, sep, seed_text = entry.name.rpartition("_seed")
        # isdecimal() (not isdigit()) so every accepted string is parseable
        # by int(); isdigit() also accepts characters such as superscripts.
        if not sep or not seed_text.isdecimal():
            continue
        by_seed.setdefault(int(seed_text), []).append(entry)
    return {seed: max(dirs) for seed, dirs in by_seed.items()}


def load_jsonl(path: Path) -> list[dict[str, Any]]:
    """Load a JSON-lines file, returning one parsed value per non-blank line.

    A missing file yields an empty list. Malformed lines are skipped (a run
    killed mid-write can leave a truncated last line) but each one is
    reported on stderr so the loss is visible in the audit log.

    Raises:
        OSError / UnicodeDecodeError: if the file exists but cannot be read
        as UTF-8 text.
    """
    records: list[dict[str, Any]] = []
    try:
        handle = path.open(encoding="utf-8")
    except FileNotFoundError:
        return records
    with handle:
        # Iterate the file rather than str.splitlines(): the latter also
        # splits on U+2028 etc., which may legally occur inside JSON strings.
        for lineno, raw in enumerate(handle, start=1):
            text = raw.strip()
            if not text:
                continue
            try:
                records.append(json.loads(text))
            except json.JSONDecodeError as exc:
                print(
                    f"WARN: {path}:{lineno}: skipping malformed JSON line ({exc.msg})",
                    file=sys.stderr,
                )
    return records


def _is_number(value: Any) -> bool:
    """True for int/float, excluding bool (which is an int subclass)."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _within_tolerance(a: Any, b: Any, tol: float) -> bool:
    """Return True if numbers *a* and *b* differ by at most *tol* (absolute)."""
    if a == b:
        # Exact match; also covers equal infinities, whose difference is NaN.
        return True
    try:
        fa, fb = float(a), float(b)
    except OverflowError:
        # An int too large for a float that is not exactly equal to the
        # other side cannot be within any sensible tolerance.
        return False
    if math.isnan(fa) or math.isnan(fb):
        # NaN compares False against everything, so `abs(diff) > tol` would
        # never fire. NaN is only in parity with NaN.
        return math.isnan(fa) and math.isnan(fb)
    return abs(fa - fb) <= tol


def walk_compare(
    a: Any, b: Any, path: str, tol: float, diffs: list[str]
) -> None:
    """Recursively compare two JSON-ish values, appending readable diffs.

    Args:
        a: Value from the first ("cpu") side.
        b: Value from the second ("gpu") side.
        path: Dotted/indexed location of the values, used in messages and to
            derive the leaf field name.
        tol: Absolute tolerance for numeric leaves that are not listed in
            INT_FIELDS_REQUIRED_EQUAL.
        diffs: Output list; one message is appended per divergence.
    """
    both_numbers = _is_number(a) and _is_number(b)
    # int vs float is allowed to mix (compared numerically below); any other
    # type difference — including bool vs number — is a divergence.
    if type(a) is not type(b) and not both_numbers:
        diffs.append(f"{path}: type mismatch {type(a).__name__} vs {type(b).__name__}")
        return

    if isinstance(a, dict):
        for key in sorted(set(a) | set(b)):
            if key in EXCLUDE_FIELDS:
                continue
            child = f"{path}.{key}" if path else key
            if key not in a:
                diffs.append(f"{child}: missing on cpu side")
            elif key not in b:
                diffs.append(f"{child}: missing on gpu side")
            else:
                walk_compare(a[key], b[key], child, tol, diffs)
        return

    if isinstance(a, list):
        if len(a) != len(b):
            diffs.append(f"{path}: list length {len(a)} vs {len(b)}")
            return
        for i, (x, y) in enumerate(zip(a, b)):
            walk_compare(x, y, f"{path}[{i}]", tol, diffs)
        return

    # Leaf: compare values.
    if isinstance(a, str):
        if a != b:
            diffs.append(f"{path}: string '{a}' != '{b}'")
        return

    if both_numbers:
        # Field name = last dotted component with any list index stripped.
        leaf_name = path.rsplit(".", 1)[-1].rsplit("[", 1)[0]
        if leaf_name in INT_FIELDS_REQUIRED_EQUAL:
            # Integer-required field: values must be equal, no tolerance.
            if a != b:
                diffs.append(f"{path}: int-required {a} != {b} (must be byte-equal)")
            return
        # Otherwise treat as float with tolerance.
        if not _within_tolerance(a, b, tol):
            diffs.append(f"{path}: float {a} != {b} (tol={tol})")
        return

    # Remaining leaves (bool, None): exact comparison.
    if a != b:
        diffs.append(f"{path}: {a!r} != {b!r}")


def compare_seed(
    dir_a: Path, dir_b: Path, seed: int, tol: float
) -> list[str]:
    """Compare the turn_stats.jsonl of two game directories for one seed.

    Returns a list of divergence messages; an empty list means parity. A
    missing stats file, or files with no records at all, count as divergence:
    a seed for which nothing was compared must not be reported green.
    """
    file_a = dir_a / _TURN_STATS_NAME
    file_b = dir_b / _TURN_STATS_NAME

    missing = [
        f"seed {seed}: {_TURN_STATS_NAME} missing on {side} side ({stats_file})"
        for side, stats_file in (("cpu", file_a), ("gpu", file_b))
        if not stats_file.is_file()
    ]
    if missing:
        return missing

    ts_a = load_jsonl(file_a)
    ts_b = load_jsonl(file_b)
    if len(ts_a) != len(ts_b):
        return [f"seed {seed}: turn count differs ({len(ts_a)} vs {len(ts_b)})"]
    if not ts_a:
        return [f"seed {seed}: {_TURN_STATS_NAME} has no records on either side"]

    diffs: list[str] = []
    for i, (line_a, line_b) in enumerate(zip(ts_a, ts_b)):
        walk_compare(line_a, line_b, f"seed{seed}.line{i}", tol, diffs)
    return diffs


def _run_comparison(args: argparse.Namespace) -> int:
    """Match seeds between the two batch dirs, print a report, return 0/1/2."""
    seeds_a = find_seed_dirs(args.dir_a)
    seeds_b = find_seed_dirs(args.dir_b)
    common = sorted(set(seeds_a) & set(seeds_b))
    only_a = sorted(set(seeds_a) - set(seeds_b))
    only_b = sorted(set(seeds_b) - set(seeds_a))

    if not common:
        print(
            f"ERROR: no overlapping seeds between {args.dir_a} and {args.dir_b}",
            file=sys.stderr,
        )
        return 2
    if only_a:
        print(f"WARN: seeds only in {args.dir_a.name}: {only_a}", file=sys.stderr)
    if only_b:
        print(f"WARN: seeds only in {args.dir_b.name}: {only_b}", file=sys.stderr)

    # A negative limit would turn the slice below into "all but the last N".
    limit = max(args.max_diff_lines, 0)
    total_diffs = 0
    failing_seeds: list[int] = []
    for seed in common:
        diffs = compare_seed(seeds_a[seed], seeds_b[seed], seed, args.float_tol)
        if not diffs:
            print(f"seed {seed}: OK")
            continue
        failing_seeds.append(seed)
        total_diffs += len(diffs)
        print(f"seed {seed}: {len(diffs)} divergence(s)")
        for line in diffs[:limit]:
            print(f" {line}")
        if len(diffs) > limit:
            print(f" ... ({len(diffs) - limit} more)")

    print("\n=== summary ===")
    print(f"seeds compared: {len(common)}")
    print(f"seeds passing: {len(common) - len(failing_seeds)}")
    print(f"seeds failing: {len(failing_seeds)} ({failing_seeds})")
    print(f"total divergences: {total_diffs}")
    print(f"float tolerance: {args.float_tol}")
    return 0 if not failing_seeds else 1


def main(argv: list[str]) -> int:
    """CLI entry point.

    Args:
        argv: Full argument vector including the program name at index 0.

    Returns:
        Process exit code: 0 parity, 1 divergence, 2 usage or I/O error.
    """
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("dir_a", type=Path, help="First batch dir (e.g. CPU run)")
    ap.add_argument("dir_b", type=Path, help="Second batch dir (e.g. GPU run)")
    ap.add_argument(
        "--float-tol",
        type=float,
        default=1e-4,
        help="Absolute tolerance for float fields (default 1e-4)",
    )
    ap.add_argument(
        "--max-diff-lines",
        type=int,
        default=20,
        help="Max diff lines to print per seed before truncating",
    )
    args = ap.parse_args(argv[1:])

    for batch_dir in (args.dir_a, args.dir_b):
        if not batch_dir.is_dir():
            print(f"ERROR: {batch_dir} is not a directory", file=sys.stderr)
            return 2

    try:
        return _run_comparison(args)
    except (OSError, UnicodeError) as exc:
        # An uncaught exception would exit with status 1, which this tool
        # reserves for "seeds diverged"; I/O failures are documented as 2.
        print(f"ERROR: I/O failure during comparison: {exc}", file=sys.stderr)
        return 2


if __name__ == "__main__":
    sys.exit(main(sys.argv))