# Listing metadata: 342 lines, 12 KiB, Python
"""Sprite generation via model-boss queue + Redis pubsub.
|
|
|
|
Two-phase architecture:
|
|
1. SUBMIT — queue generation requests to model-boss (returns immediately)
|
|
2. COLLECT — await results via Redis pubsub as GPU completes them
|
|
|
|
The CLI can submit hundreds of requests in seconds. model-boss queues them
|
|
internally and processes when GPU/VRAM is available. Results arrive via
|
|
Redis pubsub — no polling, no blocking HTTP connections.
|
|
|
|
Request IDs are stored in the variants table (job_id column) so pending
|
|
work survives process restarts.
|
|
|
|
## Variant job_status lifecycle:
|
|
submitted → completed (image received and saved)
|
|
submitted → failed (generation error or timeout)
|
|
"""
|
|
from __future__ import annotations

import asyncio
import base64
import inspect
import json
import logging
import random
from pathlib import Path

from model_boss import InferenceClient

from engine.prompts import (
    get_generation_size, get_negative, get_variant_modifier,
    compose_prompt, _extract_unit_class, _unit_classes,
)
from engine.registry import SpriteRegistry
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _extract_image_b64(result: dict) -> str | None:
|
|
"""Extract base64 image data from model-boss result.
|
|
|
|
Handles both response formats:
|
|
- InferenceClient queue: {"result": {"data": [{"b64_json": "..."}]}}
|
|
- Direct HTTP API: {"images": ["..."], "durationMs": N}
|
|
"""
|
|
# Queue result wraps in "result" key
|
|
inner = result.get("result", result)
|
|
if isinstance(inner, dict):
|
|
data_list = inner.get("data", [])
|
|
if data_list and isinstance(data_list[0], dict):
|
|
b64 = data_list[0].get("b64_json")
|
|
if b64:
|
|
return b64
|
|
images = inner.get("images", [])
|
|
if images:
|
|
return images[0]
|
|
return None
|
|
|
|
|
|
class SpriteGenerator:
    """Queue sprite generation requests on model-boss and collect the results.

    Work happens in two phases: ``submit_one`` / ``submit_batch`` enqueue
    requests and record one variant row per request, while ``collect_one`` /
    ``collect_pending`` wait for the results and write the images into
    ``raw_dir``.
    """

    def __init__(self, config: dict, registry: SpriteRegistry, raw_dir: Path):
        """Store configuration, create ``raw_dir`` and build the queue client.

        Args:
            config: Must contain ``"model"`` (passed to model-boss as the
                model name) and ``"defaults"`` (mapping consulted for
                ``steps`` and ``guidance_scale``).
            registry: Registry used to look up sprites and to create and
                update variant rows.
            raw_dir: Directory that receives raw images; created if missing.
        """
        self.model = config["model"]
        self.defaults = config["defaults"]
        self.registry = registry
        self.raw_dir = raw_dir
        self.raw_dir.mkdir(parents=True, exist_ok=True)
        self._client = InferenceClient(
            client_id="sprite-generator",
            default_priority="normal",
            timeout=600.0,
        )

    # ------------------------------------------------------------------
    # Prompt helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _compose_fresh(category: str, entity_id: str) -> tuple[str, str]:
        """Re-compose prompt + negative from current YAML for a sprite.

        Rebuilds entity_data from entity_id at submission time so YAML changes
        take effect without requiring a rescan. For units, extracts race,
        gender, and combat_type from the entity_id (e.g. "archers_dwarves_m").

        Returns (prompt, negative), or ("", "") when entity_id is empty.
        """
        if not entity_id:
            return "", ""

        if category == "units":
            # Unit class -> combat_type (empty when the class is unknown).
            unit_class = _extract_unit_class(entity_id)
            combat_type = ""
            if unit_class:
                combat_type = _unit_classes().get(unit_class, {}).get("combat_type", "")

            # Race: first known race token that appears as "_<race>".
            race = None
            for candidate in ("dwarves", "humans", "high_elves", "orcs"):
                if f"_{candidate}" in entity_id:
                    race = candidate
                    break

            # Gender: encoded as a trailing "_m" / "_f" suffix.
            gender = None
            if entity_id.endswith("_m"):
                gender = "m"
            elif entity_id.endswith("_f"):
                gender = "f"

            entity_data = {
                "entity_id": entity_id,
                "combat_type": combat_type,
                "description": "",
                "keywords": [],
            }
            # Only pass dimensions that could actually be determined.
            dimensions = {
                key: value
                for key, value in {
                    "race": race,
                    "gender": gender,
                    "quality": 2,  # default mid-tier for prompt composition
                }.items()
                if value is not None
            }
            prompt = compose_prompt(category, entity_data, dimensions)
            negative = get_negative(category, combat_type=combat_type)
            return prompt, negative

        # Non-units: use entity_id as base description
        entity_data = {"entity_id": entity_id, "name": entity_id, "description": "", "keywords": []}
        prompt = compose_prompt(category, entity_data)
        negative = get_negative(category)
        return prompt, negative

    # ------------------------------------------------------------------
    # Phase 1: SUBMIT — queue requests, return immediately
    # ------------------------------------------------------------------

    async def submit_one(
        self,
        sprite_id: str,
        prompt: str,
        negative: str,
        width: int,
        height: int,
        seed: int,
        prompt_modifier: str = "",
        priority: str = "normal",
        dimension_id: int | None = None,
    ) -> tuple[int, str] | None:
        """Submit a single generation request to model-boss queue.

        Creates variant in DB with job_status='submitted' and stores request_id.

        Args:
            sprite_id: Sprite the variant belongs to; also used as queue label.
            prompt: Base positive prompt.
            negative: Negative prompt.
            width: Requested image width.
            height: Requested image height.
            seed: Generation seed.
            prompt_modifier: Optional text appended to ``prompt``.
            priority: Queue priority forwarded to model-boss.
            dimension_id: Optional dimension row recorded on the variant.

        Returns:
            (variant_id, request_id), or None on submission failure.
        """
        full_prompt = ", ".join(part for part in (prompt, prompt_modifier) if part)

        # Resolve sampler settings once so the request and the DB row agree.
        steps = self.defaults.get("steps", 25)
        guidance_scale = self.defaults.get("guidance_scale", 7.5)

        body_fields = {
            "prompt": full_prompt,
            "negative_prompt": negative,
            "width": width,
            "height": height,
            "steps": steps,
            "guidance_scale": guidance_scale,
            "seed": seed,
            "n": 1,
        }

        try:
            request_id = await self._client.submit(
                model=self.model,
                messages=[],  # unused for image generation — body fields carry the payload
                endpoint="generate-image",
                priority=priority,
                keep_alive=300,
                context={
                    "label": sprite_id,
                    "description": f"seed={seed}",
                    "tags": ["sprite-gen"],
                },
                **body_fields,
            )
        except Exception as exc:
            # Best-effort: one failed submission must not abort a whole batch.
            logger.warning("Submit failed for %s: %s", sprite_id, exc)
            print(f"[submit-error] {sprite_id}: {exc}")
            return None

        # Create variant in DB — submitted, awaiting result
        variant_id = self.registry.add_variant(
            sprite_id=sprite_id,
            seed=seed,
            dimension_id=dimension_id,
            prompt_modifier=prompt_modifier,
            job_id=request_id,
            model=self.model,
            prompt_used=full_prompt,
            negative_used=negative,
            guidance_scale=guidance_scale,
            steps=steps,
            prompt_author="claude-opus-4-6",
        )

        return variant_id, request_id

    async def submit_batch(
        self,
        sprite_ids: list[str],
        variants_per: int = 4,
        priority: str = "normal",
    ) -> int:
        """Submit generation requests for all sprites. Returns immediately.

        Variant rows created in DB with job_status='submitted'.
        Use collect_pending() to retrieve results.

        Args:
            sprite_ids: Sprites to generate; ids unknown to the registry are
                skipped with a warning.
            variants_per: Number of variants (each with a random seed)
                requested per sprite.
            priority: Queue priority forwarded to every request.

        Returns:
            Number of requests that were successfully queued.
        """
        run_id = self.registry.start_run(variants_per=variants_per)
        submitted = 0

        for sprite_id in sprite_ids:
            sprite = self.registry.get_sprite(sprite_id)
            if not sprite:
                logger.warning("Sprite %s not found in registry; skipping", sprite_id)
                continue

            category = sprite["category"]
            entity_id = sprite.get("entity_id", "")
            gen_w, gen_h = get_generation_size(category)

            # Re-compose prompt from current YAML every submission — never use stale DB prompt.
            # Sprite records store prompts from scan time; YAML changes need live recomposition.
            prompt, negative = self._compose_fresh(category, entity_id)
            if not prompt:
                # Recomposition produced nothing: fall back to the stored prompt.
                prompt = sprite["prompt"] or ""
                negative = sprite["negative_prompt"] or get_negative(category)
            width = sprite.get("gen_width") or gen_w
            height = sprite.get("gen_height") or gen_h

            for variant_index in range(variants_per):
                seed = random.randint(0, 2**32 - 1)
                modifier = get_variant_modifier(variant_index)

                result = await self.submit_one(
                    sprite_id=sprite_id,
                    prompt=prompt,
                    negative=negative,
                    width=width,
                    height=height,
                    seed=seed,
                    prompt_modifier=modifier,
                    priority=priority,
                )
                if result:
                    submitted += 1

        # Update run with total
        with self.registry.conn:
            self.registry.conn.execute(
                "UPDATE generation_runs SET total_jobs=? WHERE id=?",
                (submitted, run_id),
            )

        print(f"Queued {submitted} requests ({len(sprite_ids)} sprites x {variants_per} variants)")
        return submitted

    # ------------------------------------------------------------------
    # Phase 2: COLLECT — await results via Redis pubsub
    # ------------------------------------------------------------------

    async def collect_one(self, variant_id: int, request_id: str, sprite_id: str = "") -> bool:
        """Wait for one generation result via Redis pubsub. Save image to disk.

        Every failure path (wait error, failed status, missing image data,
        undecodable image, disk write error) marks the variant as 'failed'.

        Args:
            variant_id: Variant row to update.
            request_id: model-boss request id to wait on.
            sprite_id: Used to build the output filename; optional.

        Returns:
            True on success, False on failure.
        """
        try:
            result = await self._client.wait_for_result(request_id, timeout=600)
        except Exception as exc:
            logger.warning("Collect failed for variant %d: %s", variant_id, exc)
            self.registry.update_variant_status(variant_id, "failed")
            print(f"[fail] variant {variant_id}: {exc}")
            return False

        if result.get("status") == "failed":
            error = result.get("error", "unknown error")
            self.registry.update_variant_status(variant_id, "failed")
            print(f"[fail] variant {variant_id}: {error}")
            return False

        image_b64 = _extract_image_b64(result)
        if not image_b64:
            print(f"[fail] variant {variant_id}: no image data (keys: {list(result.keys())})")
            self.registry.update_variant_status(variant_id, "failed")
            return False

        inner = result.get("result", result)
        duration = inner.get("durationMs") or inner.get("duration_ms") if isinstance(inner, dict) else None

        # Save raw image
        safe_name = sprite_id.replace("/", "_") if sprite_id else f"v{variant_id}"
        filename = f"{safe_name}_{variant_id}.png"
        raw_path = self.raw_dir / filename

        try:
            image_bytes = base64.b64decode(image_b64) if isinstance(image_b64, str) else image_b64
            raw_path.write_bytes(image_bytes)
        except (ValueError, TypeError, OSError) as exc:
            # binascii.Error is a ValueError. Without this guard a corrupt
            # payload or a disk error would leave the variant 'submitted'.
            logger.warning("Saving image failed for variant %d: %s", variant_id, exc)
            self.registry.update_variant_status(variant_id, "failed")
            print(f"[fail] variant {variant_id}: {exc}")
            return False

        # Update variant as completed
        self.registry.update_variant_status(
            variant_id, "completed",
            raw_path=str(raw_path),
            generation_ms=duration,
        )

        print(f"[done] variant {variant_id} ({sprite_id}) -> {filename}")
        return True

    async def collect_pending(self, on_complete=None) -> int:
        """Collect ALL pending generation results via Redis pubsub.

        Awaits all submitted variants concurrently. As each result arrives,
        saves the image and optionally calls on_complete(variant_id, sprite_id).

        Args:
            on_complete: Optional callback invoked after each successful
                collection. May be a plain function or a coroutine function;
                an awaitable return value is awaited.

        Returns:
            Count of successfully collected variants. A variant whose
            callback raised is logged and not counted.
        """
        pending = self.registry.conn.execute(
            "SELECT v.id, v.job_id, v.sprite_id FROM variants v "
            "WHERE v.job_status = 'submitted' AND v.job_id IS NOT NULL"
        ).fetchall()

        if not pending:
            return 0

        print(f"Awaiting {len(pending)} pending results via Redis pubsub...")

        async def _collect_and_notify(row):
            success = await self.collect_one(row["id"], row["job_id"], row["sprite_id"])
            if success and on_complete:
                outcome = on_complete(row["id"], row["sprite_id"])
                # Only await when the callback actually returned an awaitable,
                # so synchronous callbacks work too.
                if inspect.isawaitable(outcome):
                    await outcome
            return success

        results = await asyncio.gather(
            *(_collect_and_notify(row) for row in pending),
            return_exceptions=True,
        )

        collected = 0
        for row, outcome in zip(pending, results):
            if isinstance(outcome, BaseException):
                # Surface errors that gather() captured instead of dropping them.
                logger.error("Unhandled error collecting variant %s: %r", row["id"], outcome)
            elif outcome is True:
                collected += 1
        return collected

    async def close(self) -> None:
        """Clean up client resources (no-op if the client has no ``close``)."""
        if hasattr(self._client, 'close'):
            await self._client.close()
|