magicciv/tools/sprite-generation/engine/generator.py
2026-03-26 01:06:57 -07:00

220 lines
7.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Sprite generation via model-boss InferenceClient.
Uses the same pattern as auto-commit-service: InferenceClient submits work
to model-boss coordinator via HTTP, waits for results via Redis pubsub.
No polling. model-boss queues internally and processes when GPU is available.
"""
from __future__ import annotations
import asyncio
import base64
import json
import logging
import random
from pathlib import Path
from model_boss import InferenceClient
from engine.prompts import get_generation_size, get_negative, get_variant_modifier
from engine.registry import SpriteRegistry
logger = logging.getLogger(__name__)
# Upper bound on generation requests this process has in flight at once.
# model-boss queues submitted work itself; this only caps how many
# generate_image() calls SpriteGenerator is awaiting simultaneously
# (it is the initial value of the generator's asyncio.Semaphore).
MAX_CONCURRENT = 8
class SpriteGenerator:
    """Generate sprite variants through model-boss and record them in the registry.

    Each request is sent with ``InferenceClient.generate_image``; the returned
    image is written to ``raw_dir`` and a variant row is created and marked
    ``"completed"`` in the ``SpriteRegistry``.  At most ``MAX_CONCURRENT``
    requests are awaited at the same time.
    """

    def __init__(self, config: dict, registry: SpriteRegistry, raw_dir: Path):
        """Set up the inference client and make sure the output directory exists.

        Args:
            config: Must contain ``"model"`` (model identifier passed to
                model-boss) and ``"defaults"`` (a mapping that may provide
                ``"steps"`` and ``"guidance_scale"``).
            registry: Registry used to look up sprites and store variants/runs.
            raw_dir: Directory that receives the raw generated images; it is
                created (including parents) if missing.

        Raises:
            KeyError: If ``config`` lacks ``"model"`` or ``"defaults"``.
        """
        self.model = config["model"]
        self.defaults = config["defaults"]
        self.registry = registry
        self.raw_dir = raw_dir
        self.raw_dir.mkdir(parents=True, exist_ok=True)
        self._client = InferenceClient(
            client_id="sprite-generator",
            default_priority="normal",
            timeout=600.0,
        )
        # Bounds the number of generate_image() calls awaited concurrently.
        self._semaphore = asyncio.Semaphore(MAX_CONCURRENT)

    @staticmethod
    def _extract_image_b64(result):
        """Return the first image payload found in *result*, or ``None``.

        Two response layouts are recognised:
          * ``{"data": [{"b64_json": "..."}], ...}``
          * ``{"images": ["..."], ...}``
        The ``"data"`` layout is tried first; ``"images"`` is the fallback.
        """
        data_list = result.get("data", [])
        if data_list and isinstance(data_list[0], dict):
            payload = data_list[0].get("b64_json")
            if payload:
                return payload
        images = result.get("images", [])
        if images:
            return images[0]
        return None

    async def generate_one(
        self,
        sprite_id: str,
        prompt: str,
        negative: str,
        width: int,
        height: int,
        seed: int,
        prompt_modifier: str = "",
        priority: str = "normal",
        dimension_id: int | None = None,
    ) -> int | None:
        """Generate a single sprite image via model-boss.

        Submits the request, then saves the raw image to disk and records a
        completed variant in the registry.

        Args:
            sprite_id: Registry identifier of the sprite; ``/`` is replaced by
                ``_`` when building the output file name.
            prompt: Base positive prompt.
            negative: Negative prompt.
            width: Requested image width.
            height: Requested image height.
            seed: Seed forwarded to the model and stored with the variant.
            prompt_modifier: Optional text appended to ``prompt`` (joined with
                ``", "``) and stored with the variant.
            priority: Priority forwarded to model-boss.
            dimension_id: Optional dimension reference stored with the variant.

        Returns:
            The new variant id, or ``None`` when the request failed, the
            response carried no image, or the image payload was not valid /
            was empty.  No variant row is created in those cases.
        """
        full_prompt = ", ".join(p for p in (prompt, prompt_modifier) if p)

        async with self._semaphore:
            try:
                result = await self._client.generate_image(
                    model=self.model,
                    prompt=full_prompt,
                    negative_prompt=negative,
                    width=width,
                    height=height,
                    steps=self.defaults.get("steps", 25),
                    guidance_scale=self.defaults.get("guidance_scale", 7.5),
                    seed=seed,
                    priority=priority,
                    keep_alive=300,
                )
            except Exception as exc:
                # Best-effort boundary: one failed request must not abort a batch.
                logger.warning("Generation failed for %s: %s", sprite_id, exc)
                print(f"[error] {sprite_id}: {exc}")
                return None

        image_b64 = self._extract_image_b64(result)
        if not image_b64:
            print(f"[error] {sprite_id}: no image data in result (keys: {list(result.keys())})")
            return None

        if isinstance(image_b64, str):
            try:
                image_bytes = base64.b64decode(image_b64)
            except ValueError as exc:
                # binascii.Error (bad padding etc.) is a ValueError subclass.
                # Honour the "None on failure" contract instead of raising.
                logger.warning("Invalid base64 image for %s: %s", sprite_id, exc)
                print(f"[error] {sprite_id}: invalid base64 image data ({exc})")
                return None
        else:
            # Already raw bytes — use as-is.
            image_bytes = image_b64

        if not image_bytes:
            # Never register a zero-byte file as a completed variant.
            logger.warning("Empty image payload for %s", sprite_id)
            print(f"[error] {sprite_id}: empty image data in result")
            return None

        # The duration key differs between response formats.
        duration = result.get("durationMs") or result.get("duration_ms")

        # Create the variant first: its id is part of the file name.
        variant_id = self.registry.add_variant(
            sprite_id=sprite_id,
            seed=seed,
            dimension_id=dimension_id,
            prompt_modifier=prompt_modifier,
            job_id=None,
        )

        # Save raw image.
        # NOTE(review): if this write raises, the variant row above is left
        # without a "completed" status — confirm how the registry should
        # represent that before adding cleanup here.
        safe_name = sprite_id.replace("/", "_")
        raw_path = self.raw_dir / f"{safe_name}_{variant_id}.png"
        raw_path.write_bytes(image_bytes)

        # Mark the variant as completed.
        self.registry.update_variant_status(
            variant_id, "completed",
            raw_path=str(raw_path),
            generation_ms=duration,
        )
        return variant_id

    async def generate_batch(
        self,
        sprite_ids: list[str],
        variants_per: int = 4,
        priority: str = "normal",
    ) -> int:
        """Generate variants for a list of sprites concurrently.

        One task is created per (sprite, variant) pair; all tasks are awaited
        together while the semaphore limits how many requests are in flight.
        Sprite ids unknown to the registry are skipped with a warning.

        Args:
            sprite_ids: Registry ids of the sprites to generate.
            variants_per: Number of variants requested per sprite; each gets a
                fresh random 32-bit seed and the modifier for its index.
            priority: Priority forwarded to model-boss for every request.

        Returns:
            Number of variants generated successfully.
        """
        run_id = self.registry.start_run(variants_per=variants_per)
        tasks: list[asyncio.Task] = []
        total_planned = 0
        queued_sprites = 0

        for sprite_id in sprite_ids:
            sprite = self.registry.get_sprite(sprite_id)
            if not sprite:
                logger.warning("Skipping unknown sprite %s", sprite_id)
                continue
            queued_sprites += 1

            prompt = sprite["prompt"]
            negative = sprite["negative_prompt"] or get_negative(sprite["category"])
            gen_w, gen_h = get_generation_size(sprite["category"])
            # Per-sprite size overrides win over the category default.
            width = sprite.get("gen_width") or gen_w
            height = sprite.get("gen_height") or gen_h

            for i in range(variants_per):
                seed = random.randint(0, 2**32 - 1)
                modifier = get_variant_modifier(i)
                total_planned += 1
                tasks.append(
                    asyncio.create_task(
                        self._generate_and_report(
                            sprite_id=sprite_id,
                            prompt=prompt,
                            negative=negative,
                            width=width,
                            height=height,
                            seed=seed,
                            prompt_modifier=modifier,
                            priority=priority,
                            index=total_planned,
                        )
                    )
                )

        # Report the sprites actually queued, not the ones merely requested.
        print(f"Queued {total_planned} generation requests ({queued_sprites} sprites × {variants_per} variants)")

        # Update run with total
        with self.registry.conn:
            self.registry.conn.execute(
                "UPDATE generation_runs SET total_jobs=? WHERE id=?",
                (total_planned, run_id),
            )

        # Await all — return_exceptions=True keeps one crashed task from
        # cancelling the rest; the exceptions come back as result values.
        results = await asyncio.gather(*tasks, return_exceptions=True)

        completed = 0
        for outcome in results:
            if isinstance(outcome, BaseException):
                # Surface what would otherwise be silently counted as "failed".
                logger.error("Generation task raised: %r", outcome, exc_info=outcome)
            elif isinstance(outcome, int):
                completed += 1
        failed = total_planned - completed

        self.registry.update_run(run_id, completed_delta=completed, failed_delta=failed, finished=True)
        print(f"\nCompleted: {completed}/{total_planned} ({failed} failed)")
        return completed

    async def _generate_and_report(
        self,
        sprite_id: str,
        prompt: str,
        negative: str,
        width: int,
        height: int,
        seed: int,
        prompt_modifier: str,
        priority: str,
        index: int,
    ) -> int | None:
        """Generate one variant via ``generate_one`` and print a progress line.

        ``index`` is the 1-based position of this request within the batch and
        is only used in the progress output.  Returns whatever
        ``generate_one`` returns (variant id, or ``None`` on failure).
        """
        variant_id = await self.generate_one(
            sprite_id=sprite_id,
            prompt=prompt,
            negative=negative,
            width=width,
            height=height,
            seed=seed,
            prompt_modifier=prompt_modifier,
            priority=priority,
        )
        if variant_id is not None:
            print(f"[{index}] Done {sprite_id} (seed={seed}) → variant {variant_id}")
        return variant_id

    async def close(self) -> None:
        """Clean up client resources by closing the inference client."""
        await self._client.close()