370 lines
13 KiB
Python
370 lines
13 KiB
Python
"""Local sprite scoring using SigLIP2 zero-shot classification via model-boss.

Replaces Sonnet API calls with local GPU inference for 10-dimension sprite scoring.
Uses the same SigLIP2 model and approach as @imajin's imajin-semantic service,
but loads the model in-process: on a GPU leased from model-boss when a lease can
be acquired, otherwise on the scorer's configured device (CPU by default).

Calibration against Sonnet's scores is done via text prompt tuning: the
positive/negative text queries for each dimension are adjusted until local
scores correlate with Sonnet scores (target: Pearson r > 0.7 per dimension).
"""
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import math
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
from PIL import Image
|
|
from transformers import AutoModel, AutoProcessor
|
|
|
|
# Hugging Face checkpoint id of the SigLIP2 model used for all scoring.
MODEL_NAME = "google/siglip2-so400m-patch14-384"

# VRAM budget (MB) passed as ``vram_mb`` when requesting a model-boss lease.
VRAM_MB = 2 * 1024

# Module logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
# ---------------------------------------------------------------------------
# Per-dimension text queries for zero-shot classification.
#
# Each dimension lists positive prompts (traits we want to see) and negative
# prompts (traits we don't). LocalSpriteScorer._compute_contrastive_score runs
# all of a dimension's prompts through the model together, applies a
# temperature-scaled softmax over the image/prompt cosine similarities, and
# uses the probability mass landing on the positive prompts as the score
# (a value in 0-1). Entries left empty here are filled in per entity by
# LocalSpriteScorer.score_image.
# ---------------------------------------------------------------------------

DIMENSION_QUERIES: dict[str, dict[str, list[str]]] = dict(
    camera_angle=dict(
        positive=[
            "isometric 3/4 elevated camera angle looking down at character",
            "game sprite seen from above at 45 degree angle",
            "elevated three-quarter overhead view of character",
        ],
        negative=[
            "front-facing portrait eye-level view",
            "top-down bird's eye view looking straight down",
            "side view profile of character",
        ],
    ),
    facing_direction=dict(
        positive=[
            "character facing bottom-left walking southwest",
            "figure angled toward lower-left corner of image",
            "rear three-quarter view character walking away to the left",
        ],
        negative=[
            "character facing forward toward the camera",
            "character facing right walking to the right",
            "front view portrait looking at viewer",
        ],
    ),
    composition=dict(
        positive=[
            "single character game sprite clean silhouette isolated",
            "one figure centered in frame with clear outline",
            "single game unit on plain background",
        ],
        negative=[
            "multiple characters crowd group scene",
            "cluttered scene with many objects",
            "character sheet turnaround reference page multiple poses",
        ],
    ),
    subject_type=dict(
        positive=[],  # filled per entity from its description
        negative=[
            "abstract shape generic blob formless",
            "empty background no subject",
        ],
    ),
    # The next three are filled entirely per entity.
    race_accuracy=dict(positive=[], negative=[]),
    gender_accuracy=dict(positive=[], negative=[]),
    equipment_accuracy=dict(positive=[], negative=[]),
    pose_quality=dict(
        positive=[
            "full body character visible from head to feet standing dynamic pose",
            "complete character sprite feet visible action-ready stance",
            "game unit full body not cropped",
        ],
        negative=[
            "character cropped at knees or waist",
            "stiff T-pose or A-pose mannequin",
            "only head and shoulders portrait bust",
        ],
    ),
    background_compliance=dict(
        positive=[
            "solid plain simple background uniform color",
            "character on flat clean background easy to separate",
            "simple studio background solid color no scenery",
        ],
        negative=[
            "brown terrain ground landscape background",
            "grey studio backdrop with shadows",
            "scenery forest mountains sky behind character",
        ],
    ),
    art_style=dict(
        positive=[
            "hand-painted digital fantasy game art bold colors",
            "Warcraft III style painted game sprite rich saturated",
            "stylized fantasy RPG character art bold shapes",
        ],
        negative=[
            "photorealistic photograph realistic portrait",
            "anime manga cartoon japanese style",
            "pixel art retro 8-bit sprite",
        ],
    ),
)
|
|
|
|
# Race-specific prompt table, keyed by race name. For the race_accuracy
# dimension, LocalSpriteScorer.score_image uses the entry whose key equals the
# entity's race (when there is one).
RACE_QUERIES: dict[str, dict[str, list[str]]] = dict(
    dwarves=dict(
        positive=[
            "short stocky dwarf character wide body small stature",
            "fantasy dwarf short and broad muscular",
            "dwarven proportions compact heavy build",
        ],
        negative=[
            "tall slender human normal proportions",
            "tall elf graceful thin build",
            "large orc green skin tusks",
        ],
    ),
)
|
|
|
|
# Gender-specific prompt table, keyed by gender label. Each entry contrasts
# the requested gender (positive) against the other one (negative).
GENDER_QUERIES: dict[str, dict[str, list[str]]] = dict(
    male=dict(
        positive=[
            "male character masculine build thick beard",
            "male warrior with facial hair",
        ],
        negative=[
            "female character feminine build",
            "woman with braided hair no beard",
        ],
    ),
    female=dict(
        positive=[
            "female character feminine build braided hair no beard",
            "woman warrior sturdy feminine",
        ],
        negative=[
            "male character with thick beard masculine",
            "bearded man warrior",
        ],
    ),
)
|
|
|
|
# Positive prompt(s) per combat type, keyed by combat-type label; used for
# the equipment_accuracy dimension.
EQUIPMENT_QUERIES: dict[str, list[str]] = dict(
    melee=["warrior holding sword or axe heavy armor shield"],
    ranged=["archer holding bow and arrows quiver crossbow"],
    cavalry=["rider mounted on horse warhorse"],
    civilian=["civilian carrying tools supplies no armor"],
    specialist=["character with specialized magical or tactical equipment"],
    siege=["siege engine catapult war machine heavy wood iron"],
)
|
|
|
|
|
|
@dataclass()
class LocalScorerResult:
    """Result container for one image scored by the local SigLIP2 scorer."""

    # Mapping of dimension name -> score for that dimension.
    scores: dict[str, float]
    # Optional nested similarity breakdown; each instance gets its own empty
    # dict by default. NOTE(review): presumably dimension -> prompt ->
    # similarity, but nothing shown here populates it — confirm with callers.
    raw_similarities: dict[str, dict[str, float]] = field(
        default_factory=dict,
    )
|
|
|
|
|
|
class LocalSpriteScorer:
    """Score sprite images locally using SigLIP2 zero-shot classification.

    ``load()`` tries to acquire a model-boss GPU lease and loads SigLIP2 on the
    leased GPU, falling back to the constructor's ``device`` when the lease
    cannot be obtained; ``load_sync()`` loads without contacting model-boss.
    Images are then scored per dimension by contrasting them against positive
    and negative text prompts. Until a model is loaded, every dimension scores
    a neutral 0.5.
    """

    # Hugging Face's SigLIP/SigLIP2 documentation says the text tower was
    # trained on lowercased prompts padded with padding="max_length" to 64
    # tokens. Padding only to the longest prompt in the batch (padding=True)
    # yields text embeddings that differ from training and that vary with the
    # other prompts sharing the batch.
    # NOTE(review): prompt wording was tuned under the old preprocessing, so
    # the Sonnet correlations may need re-checking after this change.
    _TEXT_MAX_LENGTH = 64

    def __init__(self, device: str = "cpu"):
        self._device = device
        self._model = None
        self._processor = None
        # Lease object returned by model-boss in load(), kept on the instance
        # instead of being discarded.
        # NOTE(review): nothing here releases it — confirm the lease lifecycle.
        self._lease = None

    async def load(self) -> None:
        """Load SigLIP2, preferring a GPU granted by a model-boss lease.

        Any failure while importing the model-boss client, requesting the
        lease, or reading its ``gpu_index`` is logged as a warning. The model
        is then loaded on the current ``self._device`` instead.
        """
        try:
            from model_boss.client import InferenceClient

            client = InferenceClient(
                client_id="sprite-scorer",
                auto_start_services=False,
            )
            lease = await client.acquire_lease(
                model_id=f"service:{MODEL_NAME}",
                vram_mb=VRAM_MB,
                priority="normal",
            )
            self._device = f"cuda:{lease['gpu_index']}"
            self._lease = lease
            logger.info("Acquired GPU lease: %s", self._device)
        except Exception as exc:  # deliberate best-effort: the lease is optional
            logger.warning("model-boss lease failed (%s), using %s", exc, self._device)

        self._load_model()

    def load_sync(self, device: str | None = None) -> None:
        """Synchronous model loading (for non-async contexts).

        Args:
            device: Optional device string that overrides the current one
                before loading; a falsy value keeps the current device.
        """
        if device:
            self._device = device
        self._load_model()

    def _load_model(self) -> None:
        """Load the processor and frozen SigLIP2 weights onto ``self._device``."""
        logger.info("Loading SigLIP2 on %s...", self._device)
        self._processor = AutoProcessor.from_pretrained(MODEL_NAME)
        self._model = AutoModel.from_pretrained(MODEL_NAME).to(self._device)
        self._model.requires_grad_(False)
        logger.info("SigLIP2 loaded")

    @staticmethod
    def _sigmoid(x: float, center: float = 0.25, scale: float = 10.0) -> float:
        """Logistic squash of ``x`` around ``center`` with slope ``scale``.

        Uses the two-branch stable form, so a large ``|scale * (x - center)|``
        saturates to 0.0 or 1.0 instead of overflowing ``math.exp``.
        """
        z = scale * (x - center)
        if z >= 0:
            return 1.0 / (1.0 + math.exp(-z))
        e = math.exp(z)
        return e / (1.0 + e)

    def _cosine_similarities(
        self,
        image: Image.Image,
        text_prompts: list[str],
    ) -> torch.Tensor:
        """Return the cosine similarity between ``image`` and each prompt.

        The image and all prompts go through the processor and model in one
        call. The result is 1-D, with one entry per prompt in order.
        """
        inputs = self._processor(
            text=[prompt.lower() for prompt in text_prompts],
            images=image,
            padding="max_length",
            max_length=self._TEXT_MAX_LENGTH,
            truncation=True,  # long entity descriptions must fit the 64 slots
            return_tensors="pt",
        )
        inputs = {k: v.to(self._device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self._model(**inputs)
            image_embeds = outputs.image_embeds
            text_embeds = outputs.text_embeds
            image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
            text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
            # A single image was passed, so row 0 holds its similarity to
            # every prompt.
            return (image_embeds @ text_embeds.T)[0]

    def _compute_raw_similarity(
        self,
        image: Image.Image,
        text_prompts: list[str],
    ) -> list[float]:
        """Compute sigmoid-scaled cosine similarity between image and text prompts.

        Returns one value per prompt. If there are no prompts or no model is
        loaded, zeros are returned (a single ``0.0`` for an empty prompt list).
        """
        if not text_prompts or self._model is None:
            return [0.0] * max(len(text_prompts), 1)

        similarities = self._cosine_similarities(image, text_prompts)
        return [self._sigmoid(s) for s in similarities.cpu().tolist()]

    def _compute_contrastive_score(
        self,
        image: Image.Image,
        positive_prompts: list[str],
        negative_prompts: list[str],
        temperature: float = 0.01,
    ) -> float:
        """Score a dimension using softmax contrastive comparison.

        Passes ALL prompts (positive + negative) in a single forward pass,
        applies a softmax over the cosine similarities divided by
        ``temperature``, and returns the probability mass on the positive
        prompts.

        Args:
            image: RGB image to score.
            positive_prompts: Prompts describing the desired trait.
            negative_prompts: Prompts describing competing/undesired traits.
            temperature: Softmax temperature; smaller is more peaked.

        Returns:
            A score in [0, 1], where 1 means the image strongly matches the
            positives. Returns a neutral 0.5 when either prompt list is empty
            (nothing to contrast) or no model is loaded.

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature!r}")
        # With no negatives the positive mass would trivially be 1.0, which
        # carries no information; treat it like a missing dimension.
        if not positive_prompts or not negative_prompts:
            return 0.5
        if self._model is None:
            return 0.5

        n_pos = len(positive_prompts)
        similarities = self._cosine_similarities(
            image, positive_prompts + negative_prompts
        )
        probs = torch.softmax(similarities / temperature, dim=-1)
        return float(sum(probs.cpu().tolist()[:n_pos]))

    @staticmethod
    def _resolve_queries(
        dim: str,
        queries: dict[str, list[str]],
        *,
        race: str,
        gender: str,
        combat_type: str,
        entity_description: str,
    ) -> tuple[list[str], list[str]]:
        """Return fresh ``(positive, negative)`` prompt lists for ``dim``.

        Start from the dimension's static queries and swap in entity-specific
        prompts where the entity provides the needed field.
        """
        pos = list(queries["positive"])
        neg = list(queries["negative"])

        if dim == "subject_type" and entity_description:
            pos = [entity_description, f"a {entity_description}"]
        elif dim == "race_accuracy" and race in RACE_QUERIES:
            pos = list(RACE_QUERIES[race]["positive"])
            neg = list(RACE_QUERIES[race]["negative"])
        elif dim == "gender_accuracy" and gender in GENDER_QUERIES:
            pos = list(GENDER_QUERIES[gender]["positive"])
            neg = list(GENDER_QUERIES[gender]["negative"])
        elif dim == "equipment_accuracy" and combat_type:
            # An unknown combat type yields no positives -> neutral 0.5.
            pos = list(EQUIPMENT_QUERIES.get(combat_type, []))
            neg = ["wrong weapon type unarmed empty hands"]

        return pos, neg

    def score_image(
        self,
        image_path: str | Path,
        entity_id: str = "",
        race: str = "",
        gender: str = "",
        combat_type: str = "",
        entity_description: str = "",
    ) -> LocalScorerResult:
        """Score a sprite image on all 10 dimensions.

        Args:
            image_path: Path to the image file; it is converted to RGB.
            entity_id: Accepted for interface compatibility; not used here.
            race: Key into ``RACE_QUERIES`` for race_accuracy.
            gender: Key into ``GENDER_QUERIES`` for gender_accuracy.
            combat_type: Key into ``EQUIPMENT_QUERIES`` for equipment_accuracy.
            entity_description: Free-text subject used for subject_type.

        Returns:
            A ``LocalScorerResult`` with one score per ``DIMENSION_QUERIES`` key.
        """
        # The context manager closes the file handle; convert() forces a full
        # decode first, so the returned image stays usable after the close.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        scores: dict[str, float] = {}
        for dim, queries in DIMENSION_QUERIES.items():
            pos, neg = self._resolve_queries(
                dim,
                queries,
                race=race,
                gender=gender,
                combat_type=combat_type,
                entity_description=entity_description,
            )
            scores[dim] = self._compute_contrastive_score(image, pos, neg)

        return LocalScorerResult(scores=scores)

    def score_batch(
        self,
        items: list[dict],
    ) -> list[LocalScorerResult]:
        """Score multiple images, one ``score_image`` call per item, in order.

        Each item dict must have ``image_path``. It may have ``entity_id``,
        ``race``, ``gender``, ``combat_type`` and ``entity_description``;
        missing keys default to ``""``.
        """
        return [
            self.score_image(
                image_path=item["image_path"],
                entity_id=item.get("entity_id", ""),
                race=item.get("race", ""),
                gender=item.get("gender", ""),
                combat_type=item.get("combat_type", ""),
                entity_description=item.get("entity_description", ""),
            )
            for item in items
        ]
|