magicciv/tools/sprite-generation/engine/generator.py
Claude Code e9df207641 chore(engine): 🔧 Update engine build configuration for error handling
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
2026-03-29 23:39:18 -07:00

416 lines
16 KiB
Python

"""Sprite generation via model-boss queue + Redis pubsub.

Two-phase architecture:
  1. SUBMIT — queue generation requests to model-boss (returns immediately)
  2. COLLECT — await results via Redis pubsub as GPU completes them

The CLI can submit hundreds of requests in seconds. model-boss queues them
internally and processes them when GPU/VRAM is available. Results arrive via
Redis pubsub — no polling, no blocking HTTP connections.

Request IDs are stored in the variants table (job_id column) so pending
work survives process restarts.

## Variant job_status lifecycle:
  submitted → completed  (image received and saved)
  submitted → failed     (generation error or timeout)
"""
from __future__ import annotations
import asyncio
import base64
import json
import logging
import math
import random
from pathlib import Path
from model_boss import InferenceClient
from engine.prompts import (
get_generation_size, get_negative, get_variant_modifier,
compose_prompt, _extract_unit_class, _unit_classes,
)
from engine.registry import SpriteRegistry
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
def _extract_image_b64(result: dict) -> str | None:
    """Pull the first image payload out of a model-boss result dict.

    Two response layouts are understood:
      - InferenceClient queue: {"result": {"data": [{"b64_json": "..."}]}}
      - Direct HTTP API:       {"images": ["..."], "durationMs": N}

    Returns the first image entry found, or None when neither layout
    yields one.
    """
    # Queue responses nest the payload under "result"; direct responses do not.
    payload = result.get("result", result)
    if not isinstance(payload, dict):
        return None

    # Preferred layout: a list of dicts, each carrying "b64_json".
    entries = payload.get("data", [])
    if entries:
        first = entries[0]
        if isinstance(first, dict):
            encoded = first.get("b64_json")
            if encoded:
                return encoded

    # Fallback layout: a flat list of image strings.
    fallback = payload.get("images", [])
    return fallback[0] if fallback else None
class SpriteGenerator:
    """Two-phase sprite generator on top of the model-boss inference queue.

    Phase 1 (``submit_one`` / ``submit_batch``) queues generation requests and
    records one variant row per request through the registry. Phase 2
    (``collect_one`` / ``collect_pending``) waits for each request's result,
    writes the image under ``raw_dir`` and updates the variant's status.
    """

    # Seeds are drawn from, and wrapped into, the unsigned 32-bit range.
    _SEED_SPACE = 2**32

    def __init__(self, config: dict, registry: SpriteRegistry, raw_dir: Path):
        """Store configuration, ensure ``raw_dir`` exists, create the client.

        Args:
            config: Must contain ``"model"`` and ``"defaults"``; ``defaults``
                is read for ``steps`` and ``guidance_scale``.
            registry: Registry used for sprites, variants, seeds and hints.
            raw_dir: Directory that receives the raw generated images.
        """
        self.model = config["model"]
        self.defaults = config["defaults"]
        self.registry = registry
        self.raw_dir = raw_dir
        self.raw_dir.mkdir(parents=True, exist_ok=True)
        self._client = InferenceClient(
            client_id="sprite-generator",
            default_priority="normal",
            timeout=600.0,
        )

    # ------------------------------------------------------------------
    # Prompt helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _compose_fresh(category: str, entity_id: str) -> tuple[str, str]:
        """Re-compose prompt + negative from current YAML for a sprite.

        Rebuilds entity_data from entity_id at submission time so YAML changes
        take effect without requiring a rescan. For units, extracts race,
        gender and combat_type from the entity_id (e.g. "archers_dwarves_m").

        Returns:
            ``(prompt, negative)``, or ``("", "")`` when ``entity_id`` is
            empty. Errors raised by the prompt helpers are not caught here.
        """
        if not entity_id:
            return "", ""

        if category != "units":
            # Non-units: use entity_id as base description
            entity_data = {
                "entity_id": entity_id,
                "name": entity_id,
                "description": "",
                "keywords": [],
            }
            prompt = compose_prompt(category, entity_data)
            negative = get_negative(category)
            return prompt, negative

        # Unit class -> combat_type (empty when the class is unknown).
        unit_class = _extract_unit_class(entity_id)
        combat_type = ""
        if unit_class:
            combat_type = _unit_classes().get(unit_class, {}).get("combat_type", "")

        # Race: first known race token that appears as "_<race>" in the id.
        race = next(
            (r for r in ("dwarves", "humans", "high_elves", "orcs") if f"_{r}" in entity_id),
            None,
        )

        # Gender: encoded as a trailing "_m" / "_f".
        if entity_id.endswith("_m"):
            gender = "m"
        elif entity_id.endswith("_f"):
            gender = "f"
        else:
            gender = None

        entity_data = {
            "entity_id": entity_id,
            "combat_type": combat_type,
            "description": "",
            "keywords": [],
        }
        # Unknown dimensions are omitted rather than passed as None.
        dimensions = {
            key: value
            for key, value in {
                "race": race,
                "gender": gender,
                "quality": 2,  # default mid-tier for prompt composition
            }.items()
            if value is not None
        }
        prompt = compose_prompt(category, entity_data, dimensions)
        negative = get_negative(category, combat_type=combat_type, race=race or "")
        return prompt, negative

    # ------------------------------------------------------------------
    # Seed + modifier selection
    # ------------------------------------------------------------------
    @staticmethod
    def _weighted_unique_sample(quality_by_seed: dict, k: int) -> list[int]:
        """Draw up to ``k`` distinct seeds, weighted by their quality.

        Sampling is without replacement. When every weight is zero the draw
        falls back to a uniform choice instead of dividing by zero.
        """
        seeds = list(quality_by_seed)
        weights = [max(quality_by_seed[s], 0.0) for s in seeds]
        picked: list[int] = []
        while seeds and len(picked) < k:
            if sum(weights) > 0:
                idx = random.choices(range(len(seeds)), weights=weights, k=1)[0]
            else:
                idx = random.randrange(len(seeds))
            picked.append(seeds.pop(idx))
            weights.pop(idx)
        return picked

    def _select_seeds(self, category: str, entity_id: str, n: int) -> list[int]:
        """Select n distinct seeds using a 70/30 proven/random split.

        Proven seeds come from ``registry.get_proven_seeds`` and are sampled
        without replacement, weighted by ``avg_quality``. If the pool cannot
        fill the proven quota, neighbours of the best proven seed (+1, -1,
        +2, -2, +3, -3, wrapped into the 32-bit range) are added until the
        quota is met. Remaining slots are filled with random 32-bit seeds.

        Returns:
            A list of exactly ``n`` unique seeds (empty when ``n <= 0``).
        """
        if n <= 0:
            return []

        proven = self.registry.get_proven_seeds(
            category=category,
            entity_id=entity_id,
            limit=40,
            min_quality=65.0,
        )
        # ceil(0.7 * n) in integer arithmetic; math.ceil(n * 0.7) overshoots
        # for inputs such as n=10 because 10 * 0.7 == 7.000000000000001.
        proven_target = (7 * n + 9) // 10

        selected: list[int] = []
        seen: set[int] = set()

        def _take(seed: int) -> None:
            # A repeated seed would only regenerate an existing variant.
            if seed not in seen:
                seen.add(seed)
                selected.append(seed)

        if proven:
            # Collapse duplicate seeds, keeping the best quality seen for each.
            quality_by_seed: dict = {}
            for row in proven:
                quality = float(row["avg_quality"] or 0.0)
                seed = row["seed"]
                if seed not in quality_by_seed or quality > quality_by_seed[seed]:
                    quality_by_seed[seed] = quality

            quota = min(proven_target, len(quality_by_seed))
            for seed in self._weighted_unique_sample(quality_by_seed, quota):
                _take(seed)

            # Neighbor exploration: ±1..3 from best seed
            best_seed = max(quality_by_seed, key=quality_by_seed.get)
            for delta in range(1, 4):
                for offset in (delta, -delta):
                    if len(selected) < proven_target:
                        _take((best_seed + offset) % self._SEED_SPACE)

        while len(selected) < n:
            _take(random.randint(0, self._SEED_SPACE - 1))
        return selected[:n]

    def _select_modifiers(self, entity_id: str, category: str, n: int) -> list[int]:
        """Return modifier indices for n variants.

        When generation hints carry at least 20 samples and a non-empty list
        of best modifier indices, ``round(n * 0.6)`` entries are drawn at
        random from that list and the rest continue the sequential cycle from
        0. Otherwise, including missing, empty or malformed hints, the plain
        sequence ``0..n-1`` is returned.
        """
        if n <= 0:
            return []
        sequential = list(range(n))

        hints = self.registry.get_generation_hints(entity_id, category)
        if not hints:
            return sequential
        raw_mods = hints.get("best_modifier_indices")
        # A missing or None sample_count counts as "not enough samples".
        if not raw_mods or (hints.get("sample_count") or 0) < 20:
            return sequential

        try:
            good_mods = json.loads(raw_mods)
        except (TypeError, ValueError):
            logger.warning(
                "Ignoring malformed best_modifier_indices for %s/%s: %r",
                category, entity_id, raw_mods,
            )
            return sequential
        if not isinstance(good_mods, list) or not good_mods:
            return sequential

        good_count = round(n * 0.6)
        result = [random.choice(good_mods) for _ in range(good_count)]
        result += list(range(n - good_count))
        return result

    # ------------------------------------------------------------------
    # Phase 1: SUBMIT — queue requests, return immediately
    # ------------------------------------------------------------------
    async def submit_one(
        self,
        sprite_id: str,
        prompt: str,
        negative: str,
        width: int,
        height: int,
        seed: int,
        prompt_modifier: str = "",
        priority: str = "normal",
        dimension_id: int | None = None,
        guidance_scale: float | None = None,
    ) -> tuple[int, str] | None:
        """Submit a single generation request to the model-boss queue.

        On a successful submit, a variant row is created via
        ``registry.add_variant`` with the returned request id as ``job_id``.

        Returns:
            ``(variant_id, request_id)``, or None when the submit call raised
            (the error is logged and printed; no variant row is created).
        """
        full_prompt = ", ".join(part for part in (prompt, prompt_modifier) if part)
        if guidance_scale is not None:
            effective_guidance = guidance_scale
        else:
            effective_guidance = self.defaults.get("guidance_scale", 7.5)
        steps = self.defaults.get("steps", 25)

        body_fields = {
            "prompt": full_prompt,
            "negative_prompt": negative,
            "width": width,
            "height": height,
            "steps": steps,
            "guidance_scale": effective_guidance,
            "seed": seed,
            "n": 1,
        }
        try:
            request_id = await self._client.submit(
                model=self.model,
                messages=[],  # unused for image generation — body fields carry the payload
                endpoint="generate-image",
                priority=priority,
                keep_alive=300,
                context={
                    "label": sprite_id,
                    "description": f"seed={seed}",
                    "tags": ["sprite-gen"],
                },
                **body_fields,
            )
        except Exception as exc:
            # Boundary handler: one failed submit must not abort a whole batch.
            logger.warning("Submit failed for %s: %s: %r", sprite_id, type(exc).__name__, exc)
            print(f"[submit-error] {sprite_id}: {type(exc).__name__}: {exc!r}")
            return None

        # Create variant in DB — submitted, awaiting result
        variant_id = self.registry.add_variant(
            sprite_id=sprite_id,
            seed=seed,
            dimension_id=dimension_id,
            prompt_modifier=prompt_modifier,
            job_id=request_id,
            model=self.model,
            prompt_used=full_prompt,
            negative_used=negative,
            guidance_scale=effective_guidance,
            steps=steps,
            prompt_author="claude-opus-4-6",
        )
        return variant_id, request_id

    async def submit_batch(
        self,
        sprite_ids: list[str],
        variants_per: int = 4,
        priority: str = "normal",
    ) -> int:
        """Submit ``variants_per`` generation requests for each sprite.

        Sprites unknown to the registry are skipped. Use ``collect_pending()``
        afterwards to retrieve the results.

        Returns:
            The number of requests that were queued successfully; the same
            number is written to ``generation_runs.total_jobs`` for this run.
        """
        run_id = self.registry.start_run(variants_per=variants_per)
        submitted = 0

        for sprite_id in sprite_ids:
            sprite = self.registry.get_sprite(sprite_id)
            if not sprite:
                continue

            category = sprite["category"]
            entity_id = sprite.get("entity_id", "")
            gen_w, gen_h = get_generation_size(category)

            # Re-compose prompt from current YAML every submission; the prompt
            # stored on the sprite record is only a fallback when that yields
            # nothing.
            prompt, negative = self._compose_fresh(category, entity_id)
            if not prompt:
                prompt = sprite["prompt"] or ""
                negative = sprite["negative_prompt"] or get_negative(category)

            width = sprite.get("gen_width") or gen_w
            height = sprite.get("gen_height") or gen_h
            seeds = self._select_seeds(category, entity_id, variants_per)
            modifier_indices = self._select_modifiers(entity_id, category, variants_per)

            # Adaptive guidance: use the hinted value once at least 10 samples
            # exist, clamped to [6.5, 9.0]. A missing or None sample_count
            # falls back to the configured default.
            hints = self.registry.get_generation_hints(entity_id, category)
            if (
                hints
                and hints.get("best_guidance")
                and (hints.get("sample_count") or 0) >= 10
            ):
                guidance = max(6.5, min(9.0, hints["best_guidance"]))
            else:
                guidance = self.defaults.get("guidance_scale", 7.5)

            for seed, modifier_index in zip(seeds, modifier_indices):
                modifier = get_variant_modifier(modifier_index)
                outcome = await self.submit_one(
                    sprite_id=sprite_id,
                    prompt=prompt,
                    negative=negative,
                    width=width,
                    height=height,
                    seed=seed,
                    prompt_modifier=modifier,
                    priority=priority,
                    guidance_scale=guidance,
                )
                if outcome:
                    submitted += 1

        # Update run with total
        with self.registry.conn:
            self.registry.conn.execute(
                "UPDATE generation_runs SET total_jobs=? WHERE id=?",
                (submitted, run_id),
            )
        print(f"Queued {submitted} requests ({len(sprite_ids)} sprites x {variants_per} variants)")
        return submitted

    # ------------------------------------------------------------------
    # Phase 2: COLLECT — await results via Redis pubsub
    # ------------------------------------------------------------------
    def _mark_failed(self, variant_id: int, reason: object) -> bool:
        """Flag the variant as failed, print the reason, and return False."""
        self.registry.update_variant_status(variant_id, "failed")
        print(f"[fail] variant {variant_id}: {reason}")
        return False

    async def collect_one(self, variant_id: int, request_id: str, sprite_id: str = "") -> bool:
        """Wait for one generation result and save its image to ``raw_dir``.

        Every failure path marks the variant as ``failed``: a wait error, a
        non-dict or ``status == "failed"`` result, missing image data,
        undecodable base64, or a file that cannot be written.

        Returns:
            True when the image was saved and the variant marked
            ``completed``; False otherwise.
        """
        try:
            result = await self._client.wait_for_result(request_id, timeout=600)
        except Exception as exc:
            logger.warning("Collect failed for variant %d: %s", variant_id, exc)
            return self._mark_failed(variant_id, exc)

        if not isinstance(result, dict):
            return self._mark_failed(
                variant_id, f"unexpected result type {type(result).__name__}"
            )
        if result.get("status") == "failed":
            return self._mark_failed(variant_id, result.get("error", "unknown error"))

        image_b64 = _extract_image_b64(result)
        if not image_b64:
            return self._mark_failed(
                variant_id, f"no image data (keys: {list(result.keys())})"
            )

        # Timing may sit on the wrapped payload or on the top-level dict.
        inner = result.get("result", result)
        if isinstance(inner, dict):
            duration = inner.get("durationMs") or inner.get("duration_ms")
        else:
            duration = None

        # Save raw image
        safe_name = sprite_id.replace("/", "_") if sprite_id else f"v{variant_id}"
        filename = f"{safe_name}_{variant_id}.png"
        raw_path = self.raw_dir / filename
        try:
            # String payloads are base64 text; anything else is written as-is.
            image_bytes = base64.b64decode(image_b64) if isinstance(image_b64, str) else image_b64
            raw_path.write_bytes(image_bytes)
        except (ValueError, TypeError, OSError) as exc:
            # Without this the variant would stay 'submitted' indefinitely.
            logger.warning("Could not save image for variant %d: %r", variant_id, exc)
            return self._mark_failed(variant_id, f"could not save image: {exc!r}")

        # Update variant as completed
        self.registry.update_variant_status(
            variant_id, "completed",
            raw_path=str(raw_path),
            generation_ms=duration,
        )
        print(f"[done] variant {variant_id} ({sprite_id}) -> {filename}")
        return True

    async def collect_pending(self, on_complete=None) -> int:
        """Collect ALL pending generation results concurrently.

        Every variant with ``job_status = 'submitted'`` and a non-null
        ``job_id`` is awaited. After each successful save, ``on_complete``
        (if given) is called as ``on_complete(variant_id, sprite_id)``; it may
        be a plain function or a coroutine function. An exception raised by
        the callback is logged and does not change the returned count.

        Returns:
            The number of variants collected successfully.
        """
        import inspect

        pending = self.registry.conn.execute(
            "SELECT v.id, v.job_id, v.sprite_id FROM variants v "
            "WHERE v.job_status = 'submitted' AND v.job_id IS NOT NULL"
        ).fetchall()
        if not pending:
            return 0
        print(f"Awaiting {len(pending)} pending results via Redis pubsub...")

        async def _collect_and_notify(row) -> bool:
            success = await self.collect_one(row["id"], row["job_id"], row["sprite_id"])
            if success and on_complete:
                try:
                    outcome = on_complete(row["id"], row["sprite_id"])
                    if inspect.isawaitable(outcome):
                        await outcome
                except Exception:
                    logger.exception(
                        "on_complete callback failed for variant %s", row["id"]
                    )
            return success

        results = await asyncio.gather(
            *(_collect_and_notify(row) for row in pending),
            return_exceptions=True,
        )
        # gather() hands back exceptions as values; surface them in the log.
        for row, outcome in zip(pending, results):
            if isinstance(outcome, BaseException):
                logger.error("Collect task for variant %s raised: %r", row["id"], outcome)
        return sum(1 for outcome in results if outcome is True)

    async def close(self) -> None:
        """Clean up client resources, if the client exposes ``close()``."""
        import inspect

        closer = getattr(self._client, "close", None)
        if closer is None:
            return
        outcome = closer()
        # close() may be sync or async; only await when there is something to await.
        if inspect.isawaitable(outcome):
            await outcome