# magicciv/tools/sprite-generation/engine/generator.py

"""Sprite generation via model-boss queue + Redis pubsub.
Two-phase architecture:
1. SUBMIT — queue generation requests to model-boss (returns immediately)
2. COLLECT — await results via Redis pubsub as GPU completes them
The CLI can submit hundreds of requests in seconds. model-boss queues them
internally and processes when GPU/VRAM is available. Results arrive via
Redis pubsub — no polling, no blocking HTTP connections.
Request IDs are stored in the variants table (job_id column) so pending
work survives process restarts.
## Variant job_status lifecycle:
submitted → completed (image received and saved)
submitted → failed (generation error or timeout)
"""
from __future__ import annotations
import asyncio
import base64
import json
import logging
import random
from pathlib import Path
from model_boss import InferenceClient
from engine.prompts import (
get_generation_size, get_negative, get_variant_modifier,
compose_prompt, _extract_unit_class, _unit_classes,
)
from engine.registry import SpriteRegistry
# Module-level logger; submit and collect failures are reported here via logger.warning.
logger: logging.Logger = logging.getLogger(__name__)
def _extract_image_b64(result: dict) -> str | None:
"""Extract base64 image data from model-boss result.
Handles both response formats:
- InferenceClient queue: {"result": {"data": [{"b64_json": "..."}]}}
- Direct HTTP API: {"images": ["..."], "durationMs": N}
"""
# Queue result wraps in "result" key
inner = result.get("result", result)
if isinstance(inner, dict):
data_list = inner.get("data", [])
if data_list and isinstance(data_list[0], dict):
b64 = data_list[0].get("b64_json")
if b64:
return b64
images = inner.get("images", [])
if images:
return images[0]
return None
class SpriteGenerator:
    """Two-phase sprite generator backed by the model-boss inference queue.

    ``submit_one`` / ``submit_batch`` enqueue image-generation requests and,
    for each accepted request, record a variant row whose ``job_id`` is the
    model-boss request id. ``collect_one`` / ``collect_pending`` await those
    results, write the decoded image bytes to a ``.png`` file under
    ``raw_dir`` and mark each variant ``completed`` or ``failed``.
    """

    def __init__(self, config: dict, registry: SpriteRegistry, raw_dir: Path):
        """Set up the generator.

        Args:
            config: Must contain ``"model"`` (model name sent to model-boss)
                and ``"defaults"`` (a dict; ``steps`` and ``guidance_scale``
                are read from it, falling back to 25 and 7.5).
            registry: Sprite/variant database wrapper.
            raw_dir: Directory for raw generated images; created if missing.
        """
        self.model = config["model"]
        self.defaults = config["defaults"]
        self.registry = registry
        self.raw_dir = raw_dir
        self.raw_dir.mkdir(parents=True, exist_ok=True)
        self._client = InferenceClient(
            client_id="sprite-generator",
            default_priority="normal",
            timeout=600.0,
        )

    # ------------------------------------------------------------------
    # Prompt helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _compose_fresh(category: str, entity_id: str) -> tuple[str, str]:
        """Re-compose prompt + negative from current YAML for a sprite.

        Rebuilds entity_data from entity_id at submission time so YAML changes
        take effect without requiring a rescan. For units, extracts race,
        gender, and combat_type from the entity_id (e.g. "archers_dwarves_m"):
        race is the first of dwarves/humans/high_elves/orcs found as a
        ``_<race>`` substring, gender comes from a ``_m`` / ``_f`` suffix, and
        quality is fixed at 2.

        Returns (prompt, negative), or ("", "") when entity_id is empty.
        """
        if not entity_id:
            return "", ""

        if category == "units":
            # Extract unit class -> combat_type
            unit_class = _extract_unit_class(entity_id)
            combat_type = ""
            if unit_class:
                combat_type = _unit_classes().get(unit_class, {}).get("combat_type", "")

            # Parse race from entity_id (e.g. "_dwarves")
            race = None
            for r in ("dwarves", "humans", "high_elves", "orcs"):
                if f"_{r}" in entity_id:
                    race = r
                    break

            # Parse gender from entity_id
            gender = None
            if entity_id.endswith("_m"):
                gender = "m"
            elif entity_id.endswith("_f"):
                gender = "f"

            entity_data = {
                "entity_id": entity_id,
                "combat_type": combat_type,
                "description": "",
                "keywords": [],
            }
            dimensions = {
                k: v for k, v in {
                    "race": race,
                    "gender": gender,
                    "quality": 2,  # default mid-tier for prompt composition
                }.items() if v is not None
            }
            prompt = compose_prompt(category, entity_data, dimensions)
            negative = get_negative(category, combat_type=combat_type)
            return prompt, negative

        # Non-units: use entity_id as base description
        entity_data = {"entity_id": entity_id, "name": entity_id, "description": "", "keywords": []}
        prompt = compose_prompt(category, entity_data)
        negative = get_negative(category)
        return prompt, negative

    # ------------------------------------------------------------------
    # Phase 1: SUBMIT — queue requests, return immediately
    # ------------------------------------------------------------------

    async def submit_one(
        self,
        sprite_id: str,
        prompt: str,
        negative: str,
        width: int,
        height: int,
        seed: int,
        prompt_modifier: str = "",
        priority: str = "normal",
        dimension_id: int | None = None,
    ) -> tuple[int, str] | None:
        """Submit a single generation request to the model-boss queue.

        The prompt sent is ``prompt`` and ``prompt_modifier`` joined with
        ", " (empty parts skipped). On successful submission a variant row
        is created via ``registry.add_variant`` with ``job_id`` set to the
        returned request id (job_status='submitted').

        Returns:
            (variant_id, request_id), or None if submission raised (no
            variant row is created in that case).
        """
        full_prompt = ", ".join(p for p in [prompt, prompt_modifier] if p)
        body_fields = {
            "prompt": full_prompt,
            "negative_prompt": negative,
            "width": width,
            "height": height,
            "steps": self.defaults.get("steps", 25),
            "guidance_scale": self.defaults.get("guidance_scale", 7.5),
            "seed": seed,
            "n": 1,
        }
        try:
            request_id = await self._client.submit(
                model=self.model,
                messages=[],  # unused for image generation — body fields carry the payload
                endpoint="generate-image",
                priority=priority,
                keep_alive=300,
                context={
                    "label": sprite_id,
                    "description": f"seed={seed}",
                    "tags": ["sprite-gen"],
                },
                **body_fields,
            )
        except Exception as exc:
            logger.warning("Submit failed for %s: %s", sprite_id, exc)
            print(f"[submit-error] {sprite_id}: {exc}")
            return None

        # Create variant in DB — submitted, awaiting result
        variant_id = self.registry.add_variant(
            sprite_id=sprite_id,
            seed=seed,
            dimension_id=dimension_id,
            prompt_modifier=prompt_modifier,
            job_id=request_id,
            model=self.model,
            prompt_used=full_prompt,
            negative_used=negative,
            guidance_scale=self.defaults.get("guidance_scale", 7.5),
            steps=self.defaults.get("steps", 25),
            prompt_author="claude-opus-4-6",
        )
        return variant_id, request_id

    async def submit_batch(
        self,
        sprite_ids: list[str],
        variants_per: int = 4,
        priority: str = "normal",
    ) -> int:
        """Submit ``variants_per`` generation requests for every known sprite.

        Unknown sprite ids (``registry.get_sprite`` returns falsy) are
        skipped. Each variant gets a random 32-bit seed and the modifier from
        ``get_variant_modifier(i)``. Variant rows are created with
        job_status='submitted'; use collect_pending() to retrieve results.

        Returns:
            Number of requests successfully submitted (also stored as the
            run's ``total_jobs``).
        """
        run_id = self.registry.start_run(variants_per=variants_per)
        submitted = 0

        for sprite_id in sprite_ids:
            sprite = self.registry.get_sprite(sprite_id)
            if not sprite:
                continue

            category = sprite["category"]
            entity_id = sprite.get("entity_id", "")
            gen_w, gen_h = get_generation_size(category)

            # Re-compose prompt from current YAML every submission — never use stale DB prompt.
            # Sprite records store prompts from scan time; YAML changes need live recomposition.
            prompt, negative = self._compose_fresh(category, entity_id)
            if not prompt:
                # Fall back to the stored prompt only when recomposition produced nothing.
                prompt = sprite["prompt"] or ""
                negative = sprite["negative_prompt"] or get_negative(category)

            width = sprite.get("gen_width") or gen_w
            height = sprite.get("gen_height") or gen_h

            for i in range(variants_per):
                seed = random.randint(0, 2**32 - 1)
                modifier = get_variant_modifier(i)
                result = await self.submit_one(
                    sprite_id=sprite_id,
                    prompt=prompt,
                    negative=negative,
                    width=width,
                    height=height,
                    seed=seed,
                    prompt_modifier=modifier,
                    priority=priority,
                )
                if result:
                    submitted += 1

        # Update run with total
        with self.registry.conn:
            self.registry.conn.execute(
                "UPDATE generation_runs SET total_jobs=? WHERE id=?",
                (submitted, run_id),
            )

        print(f"Queued {submitted} requests ({len(sprite_ids)} sprites x {variants_per} variants)")
        return submitted

    # ------------------------------------------------------------------
    # Phase 2: COLLECT — await results via Redis pubsub
    # ------------------------------------------------------------------

    async def collect_one(self, variant_id: int, request_id: str, sprite_id: str = "") -> bool:
        """Wait for one generation result via Redis pubsub and save the image.

        The variant is marked ``completed`` (with ``raw_path`` and
        ``generation_ms``) on success, or ``failed`` when waiting raises, the
        result reports ``status == "failed"``, it carries no image data, or
        the image cannot be decoded or written to disk.

        Args:
            variant_id: Row id of the variant to update.
            request_id: model-boss request id (stored in ``variants.job_id``).
            sprite_id: Used to build the output filename (``/`` replaced by
                ``_``); ``v<variant_id>`` is used when it is empty.

        Returns:
            True on success, False on failure.
        """
        try:
            result = await self._client.wait_for_result(request_id, timeout=600)
        except Exception as exc:
            logger.warning("Collect failed for variant %d: %s", variant_id, exc)
            self.registry.update_variant_status(variant_id, "failed")
            print(f"[fail] variant {variant_id}: {exc}")
            return False

        if result.get("status") == "failed":
            error = result.get("error", "unknown error")
            self.registry.update_variant_status(variant_id, "failed")
            print(f"[fail] variant {variant_id}: {error}")
            return False

        image_b64 = _extract_image_b64(result)
        if not image_b64:
            print(f"[fail] variant {variant_id}: no image data (keys: {list(result.keys())})")
            self.registry.update_variant_status(variant_id, "failed")
            return False

        # Duration may be reported as durationMs or duration_ms inside the payload.
        inner = result.get("result", result)
        duration = (
            (inner.get("durationMs") or inner.get("duration_ms"))
            if isinstance(inner, dict)
            else None
        )

        safe_name = sprite_id.replace("/", "_") if sprite_id else f"v{variant_id}"
        filename = f"{safe_name}_{variant_id}.png"
        raw_path = self.raw_dir / filename

        try:
            image_bytes = base64.b64decode(image_b64) if isinstance(image_b64, str) else image_b64
            # Save raw image
            raw_path.write_bytes(image_bytes)
        except (ValueError, TypeError, OSError) as exc:
            # binascii.Error (malformed base64) subclasses ValueError. Previously
            # these escaped, leaving the variant stuck in 'submitted' while
            # gather(return_exceptions=True) in collect_pending hid the error.
            logger.warning("Saving image failed for variant %d: %s", variant_id, exc)
            self.registry.update_variant_status(variant_id, "failed")
            print(f"[fail] variant {variant_id}: {exc}")
            return False

        # Update variant as completed
        self.registry.update_variant_status(
            variant_id, "completed",
            raw_path=str(raw_path),
            generation_ms=duration,
        )
        print(f"[done] variant {variant_id} ({sprite_id}) -> {filename}")
        return True

    async def collect_pending(self, on_complete=None) -> int:
        """Collect ALL pending generation results via Redis pubsub.

        Selects every variant with job_status='submitted' and a non-NULL
        job_id, then awaits them concurrently with asyncio.gather. As each
        result is saved, the optional coroutine callback
        ``on_complete(variant_id, sprite_id)`` is awaited; a callback error
        is logged and does not undo or uncount the saved variant.

        Returns:
            Count of successfully collected variants.
        """
        pending = self.registry.conn.execute(
            "SELECT v.id, v.job_id, v.sprite_id FROM variants v "
            "WHERE v.job_status = 'submitted' AND v.job_id IS NOT NULL"
        ).fetchall()
        if not pending:
            return 0

        print(f"Awaiting {len(pending)} pending results via Redis pubsub...")

        async def _collect_and_notify(row):
            success = await self.collect_one(row["id"], row["job_id"], row["sprite_id"])
            if success and on_complete:
                try:
                    await on_complete(row["id"], row["sprite_id"])
                except Exception as exc:
                    # The image is already on disk and the variant is completed;
                    # don't let a callback failure drop it from the count.
                    logger.warning("on_complete failed for variant %s: %s", row["id"], exc)
            return success

        tasks = [_collect_and_notify(row) for row in pending]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # gather(return_exceptions=True) returns exceptions as values; surface
        # them instead of silently discarding them.
        for row, outcome in zip(pending, results):
            if isinstance(outcome, BaseException):
                logger.warning(
                    "Unexpected error collecting variant %s: %s",
                    row["id"], outcome, exc_info=outcome,
                )

        return sum(1 for r in results if r is True)

    async def close(self) -> None:
        """Clean up client resources (awaits the client's close() if it has one)."""
        if hasattr(self._client, 'close'):
            await self._client.close()