magicciv/tools/sprite-generation/engine/local_scorer.py

370 lines
13 KiB
Python

"""Local sprite scoring using SigLIP2 zero-shot classification via model-boss.
Replaces Sonnet API calls with local GPU inference for 10-dimension sprite scoring.
Uses the same SigLIP2 model and approach as @imajin's imajin-semantic service,
but loaded directly with a model-boss GPU lease.
Calibration against Sonnet's scores is done via text prompt tuning — adjusting
the positive/negative text queries per dimension until local scores correlate
with Sonnet scores (Pearson r > 0.7 per dimension).
"""
from __future__ import annotations
import logging
import math
from dataclasses import dataclass, field
from pathlib import Path
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor
logger = logging.getLogger(__name__)
# Hugging Face checkpoint id; used both for from_pretrained() and, prefixed
# with "service:", as the model id of the model-boss lease request.
MODEL_NAME = "google/siglip2-so400m-patch14-384"
# Value passed as ``vram_mb`` when requesting the GPU lease
# (presumably megabytes, going by the parameter name).
VRAM_MB = 2 * 1024
# ---------------------------------------------------------------------------
# Per-dimension text queries for zero-shot classification.
# Each dimension has positive prompts (what we want) and negative prompts
# (what we don't want). All prompts of a dimension are compared against the
# image together; the score is the softmax probability mass that lands on the
# positive prompts, which is already in the 0-1 range.
# ---------------------------------------------------------------------------
# Scoring dimensions, in the order they are evaluated. Every entry carries a
# "positive" and a "negative" prompt list. Entries whose lists are empty here
# are completed at scoring time from the entity's description, race, gender
# or combat type.
DIMENSION_QUERIES: dict[str, dict[str, list[str]]] = {
    # --- Viewpoint -----------------------------------------------------
    "camera_angle": {
        "positive": [
            "isometric 3/4 elevated camera angle looking down at character",
            "game sprite seen from above at 45 degree angle",
            "elevated three-quarter overhead view of character",
        ],
        "negative": [
            "front-facing portrait eye-level view",
            "top-down bird's eye view looking straight down",
            "side view profile of character",
        ],
    },
    "facing_direction": {
        "positive": [
            "character facing bottom-left walking southwest",
            "figure angled toward lower-left corner of image",
            "rear three-quarter view character walking away to the left",
        ],
        "negative": [
            "character facing forward toward the camera",
            "character facing right walking to the right",
            "front view portrait looking at viewer",
        ],
    },
    # --- Framing -------------------------------------------------------
    "composition": {
        "positive": [
            "single character game sprite clean silhouette isolated",
            "one figure centered in frame with clear outline",
            "single game unit on plain background",
        ],
        "negative": [
            "multiple characters crowd group scene",
            "cluttered scene with many objects",
            "character sheet turnaround reference page multiple poses",
        ],
    },
    # --- Entity-specific dimensions ------------------------------------
    "subject_type": {
        "positive": [],  # Filled dynamically per entity
        "negative": [
            "abstract shape generic blob formless",
            "empty background no subject",
        ],
    },
    "race_accuracy": {
        "positive": [],  # Filled dynamically per entity
        "negative": [],  # Filled dynamically per entity
    },
    "gender_accuracy": {
        "positive": [],  # Filled dynamically per entity
        "negative": [],  # Filled dynamically per entity
    },
    "equipment_accuracy": {
        "positive": [],  # Filled dynamically per entity
        "negative": [],  # Filled dynamically per entity
    },
    # --- Pose, background and rendering style --------------------------
    "pose_quality": {
        "positive": [
            "full body character visible from head to feet standing dynamic pose",
            "complete character sprite feet visible action-ready stance",
            "game unit full body not cropped",
        ],
        "negative": [
            "character cropped at knees or waist",
            "stiff T-pose or A-pose mannequin",
            "only head and shoulders portrait bust",
        ],
    },
    "background_compliance": {
        "positive": [
            "solid plain simple background uniform color",
            "character on flat clean background easy to separate",
            "simple studio background solid color no scenery",
        ],
        "negative": [
            "brown terrain ground landscape background",
            "grey studio backdrop with shadows",
            "scenery forest mountains sky behind character",
        ],
    },
    "art_style": {
        "positive": [
            "hand-painted digital fantasy game art bold colors",
            "Warcraft III style painted game sprite rich saturated",
            "stylized fantasy RPG character art bold shapes",
        ],
        "negative": [
            "photorealistic photograph realistic portrait",
            "anime manga cartoon japanese style",
            "pixel art retro 8-bit sprite",
        ],
    },
}
# Lookup tables that supply the prompts for the entity-specific dimensions.
#
# RACE_QUERIES: race key -> positive/negative prompts substituted for the
# "race_accuracy" dimension. Races without an entry keep that dimension's
# (empty) default prompts.
RACE_QUERIES: dict[str, dict[str, list[str]]] = {
    "dwarves": {
        "positive": [
            "short stocky dwarf character wide body small stature",
            "fantasy dwarf short and broad muscular",
            "dwarven proportions compact heavy build",
        ],
        "negative": [
            "tall slender human normal proportions",
            "tall elf graceful thin build",
            "large orc green skin tusks",
        ],
    },
}
# Gender key -> positive/negative prompts substituted for the
# "gender_accuracy" dimension. The two entries mirror each other: what is
# positive for one gender appears (reworded) as negative for the other.
GENDER_QUERIES: dict[str, dict[str, list[str]]] = {
    "male": {
        "positive": [
            "male character masculine build thick beard",
            "male warrior with facial hair",
        ],
        "negative": [
            "female character feminine build",
            "woman with braided hair no beard",
        ],
    },
    "female": {
        "positive": [
            "female character feminine build braided hair no beard",
            "woman warrior sturdy feminine",
        ],
        "negative": [
            "male character with thick beard masculine",
            "bearded man warrior",
        ],
    },
}
# Combat type -> positive prompts substituted for the "equipment_accuracy"
# dimension. Only positives live here; the negative prompt for this
# dimension is supplied at the point where the score is computed.
EQUIPMENT_QUERIES: dict[str, list[str]] = {
    "melee": [
        "warrior holding sword or axe heavy armor shield",
    ],
    "ranged": [
        "archer holding bow and arrows quiver crossbow",
    ],
    "cavalry": [
        "rider mounted on horse warhorse",
    ],
    "civilian": [
        "civilian carrying tools supplies no armor",
    ],
    "specialist": [
        "character with specialized magical or tactical equipment",
    ],
    "siege": [
        "siege engine catapult war machine heavy wood iron",
    ],
}
@dataclass
class LocalScorerResult:
    """Outcome of scoring one image with the local SigLIP2 scorer."""

    # Final score for each dimension, keyed by dimension name.
    scores: dict[str, float]
    # Optional nested per-dimension similarity values; defaults to a fresh
    # empty dict for every instance.
    raw_similarities: dict[str, dict[str, float]] = field(default_factory=dict)
class LocalSpriteScorer:
    """Score sprite images locally using SigLIP2 zero-shot classification.

    The scorer loads the checkpoint named by ``MODEL_NAME`` (optionally on a
    GPU obtained through a model-boss lease) and scores images by comparing
    them against dimension-specific text queries.

    Until ``load()`` or ``load_sync()`` has run, no model is available and the
    scoring helpers return neutral placeholder values instead of failing.
    """

    def __init__(self, device: str = "cpu"):
        """Remember the fallback *device*; the model itself is loaded lazily."""
        self._device = device
        self._model = None
        self._processor = None

    def _load_model(self) -> None:
        """Load processor and model onto ``self._device`` with gradients off."""
        logger.info("Loading SigLIP2 on %s...", self._device)
        self._processor = AutoProcessor.from_pretrained(MODEL_NAME)
        self._model = AutoModel.from_pretrained(MODEL_NAME).to(self._device)
        self._model.requires_grad_(False)
        logger.info("SigLIP2 loaded")

    async def load(self) -> None:
        """Load SigLIP2, preferring a GPU handed out by a model-boss lease.

        Lease acquisition is best-effort: if model-boss cannot be imported or
        the lease request fails for any reason, a warning is logged and the
        model is loaded on the device given to the constructor instead.
        """
        try:
            from model_boss.client import InferenceClient

            client = InferenceClient(
                client_id="sprite-scorer",
                auto_start_services=False,
            )
            lease = await client.acquire_lease(
                model_id=f"service:{MODEL_NAME}",
                vram_mb=VRAM_MB,
                priority="normal",
            )
            self._device = f"cuda:{lease['gpu_index']}"
            logger.info("Acquired GPU lease: %s", self._device)
        except Exception as exc:
            # Deliberately broad: any lease problem degrades to the fallback
            # device rather than aborting the load.
            logger.warning("model-boss lease failed (%s), using %s", exc, self._device)
        # NOTE(review): the lease is not kept or released here - confirm
        # whether model-boss expects an explicit release.
        self._load_model()

    def load_sync(self, device: str | None = None) -> None:
        """Synchronous model loading (for non-async contexts).

        Args:
            device: Optional device string overriding the one given to the
                constructor. No model-boss lease is requested on this path.
        """
        if device:
            self._device = device
        self._load_model()

    @staticmethod
    def _sigmoid(x: float, center: float = 0.25, scale: float = 10.0) -> float:
        """Return the logistic function of ``scale * (x - center)``.

        Inputs so far below *center* that ``exp`` would overflow saturate to
        0.0 instead of raising ``OverflowError``.
        """
        try:
            return 1.0 / (1.0 + math.exp(-scale * (x - center)))
        except OverflowError:
            return 0.0

    def _cosine_similarities(self, image: Image.Image, prompts: list[str]):
        """Return a 1-D tensor of cosine similarities, one per prompt.

        Runs a single forward pass over *image* and all *prompts*. The caller
        must ensure the model is loaded and *prompts* is non-empty.
        """
        # NOTE(review): the SigLIP2 model docs appear to recommend
        # padding="max_length", max_length=64 for text inputs - TODO confirm.
        # Left as-is because the prompts were calibrated with padding=True.
        inputs = self._processor(
            text=prompts,
            images=image,
            return_tensors="pt",
            padding=True,
        )
        inputs = {k: v.to(self._device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self._model(**inputs)
            image_embeds = outputs.image_embeds
            text_embeds = outputs.text_embeds
            # L2-normalise so the dot product below is a cosine similarity.
            image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
            text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
            # Flatten so .tolist() always yields a list, even for a single
            # prompt.
            return (image_embeds @ text_embeds.T).reshape(-1)

    def _compute_raw_similarity(
        self,
        image: Image.Image,
        text_prompts: list[str],
    ) -> list[float]:
        """Compute sigmoid-scaled cosine similarity between image and text prompts.

        Returns:
            One value per prompt. When no model is loaded or *text_prompts*
            is empty, a list of zeros with ``max(len(text_prompts), 1)``
            entries is returned, so the result is never empty.
        """
        if not text_prompts or self._model is None:
            return [0.0] * max(len(text_prompts), 1)
        raw_scores = self._cosine_similarities(image, text_prompts).cpu().tolist()
        return [self._sigmoid(s) for s in raw_scores]

    def _compute_contrastive_score(
        self,
        image: Image.Image,
        positive_prompts: list[str],
        negative_prompts: list[str],
        temperature: float = 0.01,
    ) -> float:
        """Score a dimension using softmax contrastive comparison.

        Passes ALL prompts (positive + negative) through a single forward
        pass, applies a softmax over the cosine similarities and returns the
        probability mass that falls on the positive prompts.

        Args:
            image: Image to score.
            positive_prompts: Prompts describing what the image should show.
            negative_prompts: Prompts describing what it should not show.
            temperature: Softmax temperature; smaller values make the
                distribution more peaked. Must be positive.

        Returns:
            A score in [0, 1] where 1 = image strongly matches positive.
            The neutral value 0.5 is returned when there is nothing to
            contrast (either prompt list is empty) or no model is loaded.

        Raises:
            ValueError: If *temperature* is not positive.
        """
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        # A softmax over positives alone always sums to 1.0 regardless of the
        # image, so a missing side carries no signal - report "unknown".
        if not positive_prompts or not negative_prompts:
            return 0.5
        if self._model is None:
            return 0.5
        similarities = self._cosine_similarities(
            image, positive_prompts + negative_prompts
        )
        probs = torch.softmax(similarities / temperature, dim=-1).cpu().tolist()
        # Positives come first in the combined prompt list.
        return float(sum(probs[: len(positive_prompts)]))

    def score_image(
        self,
        image_path: str | Path,
        entity_id: str = "",
        race: str = "",
        gender: str = "",
        combat_type: str = "",
        entity_description: str = "",
    ) -> LocalScorerResult:
        """Score a sprite image on all 10 dimensions.

        Args:
            image_path: Path of the image file to score.
            entity_id: Accepted for interface compatibility; not used here.
            race: Key into ``RACE_QUERIES`` for the race dimension.
            gender: Key into ``GENDER_QUERIES`` for the gender dimension.
            combat_type: Key into ``EQUIPMENT_QUERIES`` for the equipment
                dimension.
            entity_description: Free text used as the positive prompt for the
                subject-type dimension.

        Returns:
            A ``LocalScorerResult`` with one score per dimension. Dimensions
            for which no positive prompts are available score 0.5.
        """
        # Close the underlying file promptly; convert() returns an
        # independent in-memory copy that stays valid afterwards.
        with Image.open(image_path) as source:
            image = source.convert("RGB")

        scores: dict[str, float] = {}
        for dim, queries in DIMENSION_QUERIES.items():
            pos = list(queries["positive"])
            neg = list(queries["negative"])
            # Entity-specific dimensions swap in prompts derived from the
            # caller's metadata when that metadata is usable.
            if dim == "subject_type" and entity_description:
                pos = [entity_description, f"a {entity_description}"]
            elif dim == "race_accuracy" and race in RACE_QUERIES:
                pos = RACE_QUERIES[race]["positive"]
                neg = RACE_QUERIES[race]["negative"]
            elif dim == "gender_accuracy" and gender in GENDER_QUERIES:
                pos = GENDER_QUERIES[gender]["positive"]
                neg = GENDER_QUERIES[gender]["negative"]
            elif dim == "equipment_accuracy" and combat_type:
                pos = EQUIPMENT_QUERIES.get(combat_type, [])
                neg = ["wrong weapon type unarmed empty hands"]
            scores[dim] = self._compute_contrastive_score(image, pos, neg)
        return LocalScorerResult(scores=scores)

    def score_batch(
        self,
        items: list[dict],
    ) -> list[LocalScorerResult]:
        """Score multiple images, one after another.

        Each item dict has keys: image_path, entity_id, race, gender,
        combat_type, entity_description. Only ``image_path`` is required;
        the others default to the empty string.

        Returns:
            One ``LocalScorerResult`` per item, in input order.
        """
        return [
            self.score_image(
                image_path=item["image_path"],
                entity_id=item.get("entity_id", ""),
                race=item.get("race", ""),
                gender=item.get("gender", ""),
                combat_type=item.get("combat_type", ""),
                entity_description=item.get("entity_description", ""),
            )
            for item in items
        ]