"""Sprite generation via model-boss queue + Redis pubsub.

Two-phase architecture:
  1. SUBMIT — queue generation requests to model-boss (returns immediately)
  2. COLLECT — await results via Redis pubsub as GPU completes them

The CLI can submit hundreds of requests in seconds. model-boss queues them
internally and processes them when GPU/VRAM is available. Results arrive via
Redis pubsub — no polling, no blocking HTTP connections.

Request IDs are stored in the variants table (job_id column) so pending
work survives process restarts.

## Variant job_status lifecycle:
  submitted → completed (image received and saved)
  submitted → failed    (generation error or timeout)
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import base64
|
|
import json
|
|
import logging
|
|
import math
|
|
import random
|
|
from pathlib import Path
|
|
|
|
from model_boss import InferenceClient
|
|
|
|
from engine.prompts import (
|
|
get_generation_size, get_negative, get_variant_modifier,
|
|
compose_prompt, _extract_unit_class, _unit_classes,
|
|
)
|
|
from engine.registry import SpriteRegistry
|
|
|
|
# Module-level logger named after this module (used for submit/collect warnings).
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _extract_image_b64(result: dict) -> str | None:
|
|
"""Extract base64 image data from model-boss result.
|
|
|
|
Handles both response formats:
|
|
- InferenceClient queue: {"result": {"data": [{"b64_json": "..."}]}}
|
|
- Direct HTTP API: {"images": ["..."], "durationMs": N}
|
|
"""
|
|
# Queue result wraps in "result" key
|
|
inner = result.get("result", result)
|
|
if isinstance(inner, dict):
|
|
data_list = inner.get("data", [])
|
|
if data_list and isinstance(data_list[0], dict):
|
|
b64 = data_list[0].get("b64_json")
|
|
if b64:
|
|
return b64
|
|
images = inner.get("images", [])
|
|
if images:
|
|
return images[0]
|
|
return None
|
|
|
|
|
|
class SpriteGenerator:
    """Submit sprite-generation requests to model-boss and collect the images.

    Phase 1 (``submit_one`` / ``submit_batch``) queues requests and records one
    variant row per request, storing the model-boss request id as ``job_id``.
    Phase 2 (``collect_one`` / ``collect_pending``) awaits those request ids,
    writes the returned image bytes into ``raw_dir`` and updates the variant's
    job status to ``completed`` or ``failed``.
    """

    def __init__(self, config: dict, registry: SpriteRegistry, raw_dir: Path):
        """Create a generator.

        Args:
            config: must contain ``"model"`` (model name sent to model-boss) and
                ``"defaults"`` (a dict; ``steps`` and ``guidance_scale`` are read
                from it, falling back to 25 and 7.5).
            registry: sprite/variant persistence layer.
            raw_dir: directory for raw generated images; created if missing.
        """
        self.model = config["model"]
        self.defaults = config["defaults"]
        self.registry = registry
        self.raw_dir = raw_dir
        self.raw_dir.mkdir(parents=True, exist_ok=True)
        self._client = InferenceClient(
            client_id="sprite-generator",
            default_priority="normal",
            timeout=600.0,
        )

    # ------------------------------------------------------------------
    # Prompt helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _compose_fresh(category: str, entity_id: str) -> tuple[str, str]:
        """Re-compose prompt + negative from current YAML for a sprite.

        Rebuilds entity_data from entity_id at submission time so YAML changes
        take effect without requiring a rescan. For units, derives the unit
        class (-> combat_type), race and gender from the entity_id
        (e.g. "archers_dwarves_m").

        Returns (prompt, negative), or ("", "") when entity_id is empty.
        Errors raised by the prompt helpers are not caught here.
        """
        if not entity_id:
            return "", ""

        if category == "units":
            # Unit class -> combat_type via the unit-class table.
            unit_class = _extract_unit_class(entity_id)
            combat_type = ""
            if unit_class:
                combat_type = _unit_classes().get(unit_class, {}).get("combat_type", "")

            # Race: first known race whose "_<race>" token appears in the id.
            race = None
            for r in ("dwarves", "humans", "high_elves", "orcs"):
                if f"_{r}" in entity_id:
                    race = r
                    break

            # Gender: trailing "_m" / "_f" suffix.
            gender = None
            if entity_id.endswith("_m"):
                gender = "m"
            elif entity_id.endswith("_f"):
                gender = "f"

            entity_data = {
                "entity_id": entity_id,
                "combat_type": combat_type,
                "description": "",
                "keywords": [],
            }
            # Drop dimensions that could not be parsed from the id.
            dimensions = {
                k: v for k, v in {
                    "race": race,
                    "gender": gender,
                    "quality": 2,  # default mid-tier for prompt composition
                }.items() if v is not None
            }
            prompt = compose_prompt(category, entity_data, dimensions)
            negative = get_negative(category, combat_type=combat_type, race=race or "")
            return prompt, negative

        # Non-units: use entity_id as base description
        entity_data = {"entity_id": entity_id, "name": entity_id, "description": "", "keywords": []}
        prompt = compose_prompt(category, entity_data)
        negative = get_negative(category)
        return prompt, negative

    # ------------------------------------------------------------------
    # Seed + modifier selection
    # ------------------------------------------------------------------

    def _select_seeds(self, category: str, entity_id: str, n: int) -> list[int]:
        """Select n seeds using a 70/30 proven/random split.

        Proven seeds come from ``registry.get_proven_seeds`` (up to 40 entries
        with ``min_quality=65.0``) and are drawn weighted by ``avg_quality``
        WITHOUT replacement, so a batch never repeats a proven seed. If the pool
        cannot fill the proven quota (ceil(0.7 * n)), the remaining proven slots
        are filled with the best seed's neighbours (+1..+3, mod 2**32), skipping
        any already chosen. Everything else is a uniformly random 32-bit seed.

        Returns exactly n seeds (an empty list for n == 0).
        """
        proven = self.registry.get_proven_seeds(
            category=category,
            entity_id=entity_id,
            limit=40,
            min_quality=65.0,
        )
        proven_target = math.ceil(n * 0.7)
        selected: list[int] = []

        def _quality(entry) -> float:
            # Treat a missing quality as 0 so it neither crashes nor dominates.
            return float(entry["avg_quality"] or 0.0)

        if proven:
            # random.choices samples WITH replacement, which let one strong seed
            # fill several slots of the same batch. Draw one at a time and remove
            # it from the pool instead.
            pool = list(proven)
            weights = [max(_quality(p), 0.0) for p in pool]
            for _ in range(min(proven_target, len(pool))):
                if sum(weights) > 0:
                    idx = random.choices(range(len(pool)), weights=weights)[0]
                else:
                    # All-zero weights: random.choices would reject them (and the
                    # old normalisation divided by zero); fall back to uniform.
                    idx = random.randrange(len(pool))
                selected.append(pool.pop(idx)["seed"])
                weights.pop(idx)

            # Neighbor exploration: +1..+3 from the best seed, only while the
            # proven quota is still short, never repeating a chosen seed.
            best_seed = max(proven, key=_quality)["seed"]
            for delta in range(1, 4):
                if len(selected) >= proven_target:
                    break
                neighbor = (best_seed + delta) % (2**32)
                if neighbor not in selected:
                    selected.append(neighbor)

        while len(selected) < n:
            selected.append(random.randint(0, 2**32 - 1))

        return selected[:n]

    def _select_modifiers(self, entity_id: str, category: str, n: int) -> list[int]:
        """Return modifier indices for n variants.

        When generation_hints has >= 20 samples and a non-empty
        ``best_modifier_indices`` (a JSON list), round(0.6 * n) indices are
        drawn from that list and the rest are 0, 1, 2, ...  Otherwise returns
        the plain sequence ``[0, 1, ..., n - 1]``.
        """
        hints = self.registry.get_generation_hints(entity_id, category)
        if (
            hints
            and hints.get("best_modifier_indices")
            # sample_count may be NULL in the DB; treat it as "no samples".
            and (hints.get("sample_count") or 0) >= 20
        ):
            good_mods: list[int] = json.loads(hints["best_modifier_indices"])
            if good_mods:
                good_count = round(n * 0.6)
                result = [random.choice(good_mods) for _ in range(good_count)]
                result += list(range(n - good_count))
                return result
        return list(range(n))

    # ------------------------------------------------------------------
    # Phase 1: SUBMIT — queue requests, return immediately
    # ------------------------------------------------------------------

    async def submit_one(
        self,
        sprite_id: str,
        prompt: str,
        negative: str,
        width: int,
        height: int,
        seed: int,
        prompt_modifier: str = "",
        priority: str = "normal",
        dimension_id: int | None = None,
        guidance_scale: float | None = None,
    ) -> tuple[int, str] | None:
        """Submit a single generation request to model-boss queue.

        The sent prompt is ``prompt`` and ``prompt_modifier`` joined with ", "
        (empty parts skipped). ``guidance_scale`` defaults to
        ``defaults["guidance_scale"]`` (7.5 if unset).

        Only after a successful submit is a variant row created with the
        returned request id as ``job_id``.

        Returns (variant_id, request_id), or None if the submit call raised
        (the error is logged and printed).
        """
        full_prompt = ", ".join(p for p in [prompt, prompt_modifier] if p)

        effective_guidance = (
            guidance_scale if guidance_scale is not None
            else self.defaults.get("guidance_scale", 7.5)
        )
        steps = self.defaults.get("steps", 25)
        body_fields = {
            "prompt": full_prompt,
            "negative_prompt": negative,
            "width": width,
            "height": height,
            "steps": steps,
            "guidance_scale": effective_guidance,
            "seed": seed,
            "n": 1,
        }

        try:
            request_id = await self._client.submit(
                model=self.model,
                messages=[],  # unused for image generation — body fields carry the payload
                endpoint="generate-image",
                priority=priority,
                keep_alive=300,
                context={
                    "label": sprite_id,
                    "description": f"seed={seed}",
                    "tags": ["sprite-gen"],
                },
                **body_fields,
            )
        except Exception as exc:
            logger.warning("Submit failed for %s: %s: %r", sprite_id, type(exc).__name__, exc)
            print(f"[submit-error] {sprite_id}: {type(exc).__name__}: {exc!r}")
            return None

        # Create variant in DB — submitted, awaiting result
        variant_id = self.registry.add_variant(
            sprite_id=sprite_id,
            seed=seed,
            dimension_id=dimension_id,
            prompt_modifier=prompt_modifier,
            job_id=request_id,
            model=self.model,
            prompt_used=full_prompt,
            negative_used=negative,
            guidance_scale=effective_guidance,
            steps=steps,
            prompt_author="claude-opus-4-6",
        )

        return variant_id, request_id

    async def submit_batch(
        self,
        sprite_ids: list[str],
        variants_per: int = 4,
        priority: str = "normal",
    ) -> int:
        """Submit generation requests for all sprites. Returns immediately.

        Starts a generation run, submits ``variants_per`` requests for every
        sprite id the registry knows (unknown ids are skipped) and records the
        number of successful submissions as the run's ``total_jobs``.
        Use collect_pending() to retrieve results.

        Returns the number of successfully queued requests.
        """
        run_id = self.registry.start_run(variants_per=variants_per)
        submitted = 0

        for sprite_id in sprite_ids:
            sprite = self.registry.get_sprite(sprite_id)
            if not sprite:
                continue

            category = sprite["category"]
            entity_id = sprite.get("entity_id", "")
            gen_w, gen_h = get_generation_size(category)

            # Re-compose prompt from current YAML every submission — never use stale DB prompt.
            # Sprite records store prompts from scan time; YAML changes need live recomposition.
            prompt, negative = self._compose_fresh(category, entity_id)
            if not prompt:
                # Recomposition produced nothing (e.g. no entity_id): fall back to the stored prompt.
                prompt = sprite["prompt"] or ""
                negative = sprite["negative_prompt"] or get_negative(category)
            width = sprite.get("gen_width") or gen_w
            height = sprite.get("gen_height") or gen_h

            seeds = self._select_seeds(category, entity_id, variants_per)
            modifier_indices = self._select_modifiers(entity_id, category, variants_per)

            # Adaptive guidance: use generation_hints when >= 10 samples exist,
            # clamped to [6.5, 9.0]. A NULL sample_count counts as 0.
            hints = self.registry.get_generation_hints(entity_id, category)
            if hints and hints.get("best_guidance") and (hints.get("sample_count") or 0) >= 10:
                guidance = max(6.5, min(9.0, hints["best_guidance"]))
            else:
                guidance = self.defaults.get("guidance_scale", 7.5)

            for seed, modifier_index in zip(seeds, modifier_indices):
                result = await self.submit_one(
                    sprite_id=sprite_id,
                    prompt=prompt,
                    negative=negative,
                    width=width,
                    height=height,
                    seed=seed,
                    prompt_modifier=get_variant_modifier(modifier_index),
                    priority=priority,
                    guidance_scale=guidance,
                )
                if result:
                    submitted += 1

        # Update run with total
        with self.registry.conn:
            self.registry.conn.execute(
                "UPDATE generation_runs SET total_jobs=? WHERE id=?",
                (submitted, run_id),
            )

        print(f"Queued {submitted} requests ({len(sprite_ids)} sprites x {variants_per} variants)")
        return submitted

    # ------------------------------------------------------------------
    # Phase 2: COLLECT — await results via Redis pubsub
    # ------------------------------------------------------------------

    def _fail_variant(self, variant_id: int, reason: str) -> bool:
        """Mark a variant failed, print the reason and return False."""
        self.registry.update_variant_status(variant_id, "failed")
        print(f"[fail] variant {variant_id}: {reason}")
        return False

    async def collect_one(self, variant_id: int, request_id: str, sprite_id: str = "") -> bool:
        """Wait for one generation result via Redis pubsub. Save image to disk.

        The image is written to ``raw_dir / f"{sprite_id with '/' -> '_'}_{variant_id}.png"``
        (or ``v{variant_id}_{variant_id}.png`` without a sprite_id), and the
        variant is marked ``completed`` with its path and generation time.

        Every failure — wait error/timeout, a "failed" status, a missing or
        undecodable image, or a failed disk write — marks the variant
        ``failed``, so no variant is left stuck in ``submitted``.

        Returns True on success, False on failure.
        """
        try:
            result = await self._client.wait_for_result(request_id, timeout=600)
        except Exception as exc:
            logger.warning("Collect failed for variant %d: %s", variant_id, exc)
            return self._fail_variant(variant_id, str(exc))

        # Anything without .get() (e.g. None) cannot be a result payload.
        if not hasattr(result, "get"):
            return self._fail_variant(
                variant_id, f"unexpected result type {type(result).__name__}"
            )

        if result.get("status") == "failed":
            return self._fail_variant(variant_id, str(result.get("error", "unknown error")))

        image_b64 = _extract_image_b64(result)
        if not image_b64:
            return self._fail_variant(
                variant_id, f"no image data (keys: {list(result.keys())})"
            )

        try:
            image_bytes = base64.b64decode(image_b64) if isinstance(image_b64, str) else image_b64
        except ValueError as exc:  # binascii.Error (bad padding) subclasses ValueError
            logger.warning("Undecodable image for variant %d: %s", variant_id, exc)
            return self._fail_variant(variant_id, f"undecodable image data: {exc}")

        # Duration lives next to the image data (queue payload or direct response).
        inner = result.get("result", result)
        duration = (
            (inner.get("durationMs") or inner.get("duration_ms"))
            if isinstance(inner, dict) else None
        )

        # Save raw image
        safe_name = sprite_id.replace("/", "_") if sprite_id else f"v{variant_id}"
        filename = f"{safe_name}_{variant_id}.png"
        raw_path = self.raw_dir / filename
        try:
            raw_path.write_bytes(image_bytes)
        except OSError as exc:
            logger.warning("Could not save variant %d to %s: %s", variant_id, raw_path, exc)
            return self._fail_variant(variant_id, f"could not write {raw_path}: {exc}")

        # Update variant as completed
        self.registry.update_variant_status(
            variant_id, "completed",
            raw_path=str(raw_path),
            generation_ms=duration,
        )

        print(f"[done] variant {variant_id} ({sprite_id}) -> {filename}")
        return True

    async def collect_pending(self, on_complete=None) -> int:
        """Collect ALL pending generation results via Redis pubsub.

        Awaits every variant with ``job_status='submitted'`` and a non-NULL
        ``job_id`` concurrently. After each successful save, awaits
        ``on_complete(variant_id, sprite_id)`` if given; a callback error is
        logged and does not un-count the already saved variant. Unexpected
        task exceptions are logged rather than silently dropped.

        Returns count of successfully collected variants.
        """
        pending = self.registry.conn.execute(
            "SELECT v.id, v.job_id, v.sprite_id FROM variants v "
            "WHERE v.job_status = 'submitted' AND v.job_id IS NOT NULL"
        ).fetchall()

        if not pending:
            return 0

        print(f"Awaiting {len(pending)} pending results via Redis pubsub...")

        async def _collect_and_notify(row) -> bool:
            success = await self.collect_one(row["id"], row["job_id"], row["sprite_id"])
            if success and on_complete:
                try:
                    await on_complete(row["id"], row["sprite_id"])
                except Exception as exc:
                    # The image is saved and marked completed; the callback
                    # failing must not make it look uncollected.
                    logger.warning("on_complete failed for variant %d: %r", row["id"], exc)
            return success

        results = await asyncio.gather(
            *(_collect_and_notify(row) for row in pending),
            return_exceptions=True,
        )
        for row, outcome in zip(pending, results):
            if isinstance(outcome, BaseException):
                logger.warning("Collect task for variant %d raised: %r", row["id"], outcome)
        return sum(1 for r in results if r is True)

    async def close(self) -> None:
        """Clean up client resources (calls the client's close() if it has one)."""
        if hasattr(self._client, 'close'):
            await self._client.close()
|