magicciv/tools/sprite-generation/engine/calibration.py

247 lines
8 KiB
Python

"""Calibration tool: compare local SigLIP2 scores against Sonnet answer key.
Scores all labeled images with the local scorer, then generates a comparison
report showing per-dimension correlation, systematic biases, and disagreements.
Usage:
python3 -m engine.calibration [--device cuda:0] [--limit 20]
"""
from __future__ import annotations
import json
import logging
import sqlite3
import statistics
from pathlib import Path
from engine.local_scorer import LocalSpriteScorer
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Directory two levels above this file, i.e. the parent of the directory
# that contains this module.
TOOL_DIR = Path(__file__).resolve().parent.parent
# SQLite database file opened by load_sonnet_scores().
DB_PATH = TOOL_DIR / "spritegen.db"
# Ten scoring-dimension names; judging by the constant's name these are the
# dimensions used for unit sprites.
# NOTE(review): not referenced anywhere else in the visible part of this
# module — confirm whether other modules import it.
UNIT_DIMS = (
"camera_angle", "facing_direction", "composition", "subject_type",
"race_accuracy", "gender_accuracy", "equipment_accuracy",
"pose_quality", "background_compliance", "art_style",
)
def _parse_entity_metadata(entity_id: str) -> dict[str, str]:
"""Extract race and gender from entity_id like 'bowmen_dwarves_f'."""
parts = entity_id
meta: dict[str, str] = {}
if parts.endswith("_m"):
meta["gender"] = "male"
parts = parts[:-2]
elif parts.endswith("_f"):
meta["gender"] = "female"
parts = parts[:-2]
for race in ("dwarves", "humans", "high_elves", "orcs"):
suffix = f"_{race}"
if parts.endswith(suffix):
meta["race"] = race
parts = parts[:-len(suffix)]
break
meta["base_unit"] = parts
return meta
def _pearson_r(xs: list[float], ys: list[float]) -> float:
"""Pearson correlation coefficient between two lists."""
n = len(xs)
if n < 3:
return 0.0
mean_x = sum(xs) / n
mean_y = sum(ys) / n
dx = [x - mean_x for x in xs]
dy = [y - mean_y for y in ys]
num = sum(a * b for a, b in zip(dx, dy))
den_x = sum(a * a for a in dx) ** 0.5
den_y = sum(a * a for a in dy) ** 0.5
if den_x == 0 or den_y == 0:
return 0.0
return num / (den_x * den_y)
def load_sonnet_scores(limit: int | None = None) -> list[dict]:
    """Load Sonnet-scored variants whose image still exists on disk.

    Reads the SQLite database at ``DB_PATH`` and selects every variant that
    has non-NULL ``notes``, a non-NULL ``raw_path`` and a ``rating`` that is
    neither NULL nor ``-1``.  ``notes`` is parsed as a JSON object of scores.

    Args:
        limit: Maximum number of rows to fetch from the database.  ``None``
            or ``0`` means no limit.  The limit is applied by the query,
            before rows with missing image files or unusable notes are
            dropped, so fewer than ``limit`` items may be returned.

    Returns:
        One dict per usable variant with the keys ``variant_id``,
        ``raw_path``, ``entity_id``, ``prompt``, ``category``,
        ``sonnet_scores`` (the parsed notes), ``sonnet_dims`` (sorted tuple
        of the score keys), ``race``, ``gender`` and ``base_unit`` (the last
        three derived from ``entity_id``; empty string when absent).
    """
    query = """
        SELECT v.id as variant_id, v.raw_path, v.notes, s.entity_id, s.prompt, s.category
        FROM variants v
        JOIN sprites s ON v.sprite_id = s.id
        WHERE v.notes IS NOT NULL
          AND v.rating IS NOT NULL AND v.rating != -1
          AND v.raw_path IS NOT NULL
    """
    params: tuple = ()
    if limit:
        # Bind the value instead of interpolating it into the SQL text.
        query += " LIMIT ?"
        params = (limit,)

    conn = sqlite3.connect(str(DB_PATH))
    try:
        conn.row_factory = sqlite3.Row
        rows = conn.execute(query, params).fetchall()
    finally:
        # Release the connection even if the query fails; the fetched rows
        # are already materialised and stay readable afterwards.
        conn.close()

    results: list[dict] = []
    for row in rows:
        path = Path(row["raw_path"])
        if not path.exists():
            continue

        try:
            sonnet_scores = json.loads(row["notes"])
        except (ValueError, TypeError):
            # json.JSONDecodeError is a ValueError; TypeError covers notes
            # stored as a non-text value.  One bad row should not abort the
            # whole calibration run.
            logger.warning(
                "Skipping variant %s: notes is not valid JSON", row["variant_id"]
            )
            continue
        if not isinstance(sonnet_scores, dict):
            logger.warning(
                "Skipping variant %s: notes is not a JSON object", row["variant_id"]
            )
            continue

        meta = _parse_entity_metadata(row["entity_id"])
        results.append({
            "variant_id": row["variant_id"],
            "raw_path": str(path),
            "entity_id": row["entity_id"],
            "prompt": row["prompt"],
            "category": row["category"],
            "sonnet_scores": sonnet_scores,
            "sonnet_dims": tuple(sorted(sonnet_scores.keys())),
            "race": meta.get("race", ""),
            "gender": meta.get("gender", ""),
            "base_unit": meta.get("base_unit", ""),
        })
    return results
def run_calibration(device: str = "cuda:0", limit: int | None = None) -> dict:
    """Score all labeled images with the local scorer and compare to Sonnet.

    Loads the labeled variants via ``load_sonnet_scores``, scores each image
    with ``LocalSpriteScorer``, then prints a per-dimension table (Pearson r,
    averages, bias) followed by the largest disagreements.

    Args:
        device: Device string passed to ``LocalSpriteScorer`` and its
            ``load_sync`` method.
        limit: Forwarded to ``load_sonnet_scores``.

    Returns:
        An empty dict when no labeled data is available.  Otherwise a dict
        with ``n_variants``, ``per_dimension`` (dimension name -> dict of
        ``pearson_r``, ``sonnet_avg``, ``local_avg``, ``bias``, each rounded
        to three decimals), ``n_disagreements`` and ``top_disagreements``
        (at most 20 entries, largest difference first).
    """
    labeled = load_sonnet_scores(limit=limit)
    if not labeled:
        print("No labeled data found.")
        return {}
    print(f"Loaded {len(labeled)} Sonnet-scored variants with images on disk.")

    scorer = LocalSpriteScorer(device=device)
    scorer.load_sync(device=device)

    # Collect all unique Sonnet dimensions across the dataset
    all_sonnet_dims: set[str] = set()
    for item in labeled:
        all_sonnet_dims.update(item["sonnet_scores"].keys())
    all_dims = sorted(all_sonnet_dims)

    per_dim_sonnet: dict[str, list[float]] = {d: [] for d in all_dims}
    per_dim_local: dict[str, list[float]] = {d: [] for d in all_dims}
    disagreements: list[dict] = []

    # Map old 4-dim names to our 10-dim local scorer queries
    # Old dims: perspective, composition, subject_accuracy, production_quality
    # These map approximately to our dimensions
    old_to_local_map = {
        "perspective": "camera_angle",
        "composition": "composition",
        "subject_accuracy": "subject_type",
        "production_quality": "art_style",
    }

    for i, item in enumerate(labeled):
        if item["category"] == "units":
            result = scorer.score_image(
                image_path=item["raw_path"],
                entity_id=item["entity_id"],
                race=item["race"],
                gender=item["gender"],
                entity_description=item["base_unit"].replace("_", " "),
            )
        else:
            # Non-unit categories — score with generic queries
            result = scorer.score_image(
                image_path=item["raw_path"],
                entity_id=item["entity_id"],
            )
        local_scores = result.scores

        for dim, s_val in item["sonnet_scores"].items():
            # Find corresponding local dimension; names outside the old
            # 4-dim scheme are looked up unchanged.
            local_dim = old_to_local_map.get(dim, dim)
            # Falls back to 0.5 when the local scorer has no score for it.
            l_val = local_scores.get(local_dim, 0.5)
            per_dim_sonnet[dim].append(s_val)
            per_dim_local[dim].append(l_val)

            diff = abs(s_val - l_val)
            if diff > 0.3:
                disagreements.append({
                    "variant_id": item["variant_id"],
                    "entity_id": item["entity_id"],
                    "dimension": dim,
                    "local_dimension": local_dim,
                    "sonnet": s_val,
                    "local": l_val,
                    "diff": diff,
                })

        if (i + 1) % 10 == 0:
            print(f"  Scored {i + 1}/{len(labeled)}...")

    # Compute per-dimension statistics
    report: dict[str, dict] = {}
    print(f"\n{'='*80}")
    print(f"CALIBRATION REPORT — {len(labeled)} variants")
    print(f"{'='*80}")
    print(f"\n{'Dimension':<25} {'Pearson r':>10} {'Sonnet avg':>12} {'Local avg':>12} {'Bias':>8}")
    print("-" * 70)
    for dim in all_dims:
        s_vals = per_dim_sonnet[dim]
        l_vals = per_dim_local[dim]
        r = _pearson_r(s_vals, l_vals)
        s_avg = statistics.mean(s_vals)
        l_avg = statistics.mean(l_vals)
        bias = l_avg - s_avg
        report[dim] = {
            "pearson_r": round(r, 3),
            "sonnet_avg": round(s_avg, 3),
            "local_avg": round(l_avg, 3),
            "bias": round(bias, 3),
        }
        # One-character marker: "+" strong (r > 0.7), "~" moderate
        # (r > 0.4), "!" weak.  It must be exactly one character wide so
        # that marker + space + 23-wide name lines up with the 25-wide
        # "Dimension" header column.
        r_indicator = "+" if r > 0.7 else "~" if r > 0.4 else "!"
        print(f"{r_indicator} {dim:<23} {r:>10.3f} {s_avg:>12.3f} {l_avg:>12.3f} {bias:>+8.3f}")

    # Top disagreements
    disagreements.sort(key=lambda d: d["diff"], reverse=True)
    print(f"\n{'='*80}")
    print(f"TOP DISAGREEMENTS (|diff| > 0.3) — {len(disagreements)} total")
    print(f"{'='*80}")
    for d in disagreements[:20]:
        print(
            f"  #{d['variant_id']} {d['entity_id']:30s} {d['dimension']:25s} "
            f"sonnet={d['sonnet']:.2f} local={d['local']:.2f} diff={d['diff']:+.2f}"
        )

    return {
        "n_variants": len(labeled),
        "per_dimension": report,
        "n_disagreements": len(disagreements),
        "top_disagreements": disagreements[:20],
    }
if __name__ == "__main__":
    # Command-line entry point: parse the two optional flags and run the
    # calibration once with them.
    import argparse

    cli = argparse.ArgumentParser(description="Calibrate local scorer against Sonnet")
    cli.add_argument("--device", default="cuda:0", help="GPU device (default: cuda:0)")
    cli.add_argument("--limit", type=int, help="Max variants to score (for quick testing)")

    logging.basicConfig(level=logging.INFO)

    # The parsed namespace holds exactly ``device`` and ``limit``, which are
    # the keyword arguments run_calibration() accepts.
    opts = cli.parse_args()
    run_calibration(**vars(opts))