220 lines
7.5 KiB
Python
220 lines
7.5 KiB
Python
"""Sprite generation via model-boss InferenceClient.
|
||
|
||
Uses the same pattern as auto-commit-service: InferenceClient submits work
to the model-boss coordinator via HTTP and waits for results via Redis pubsub.
No polling. model-boss queues internally and processes when a GPU is available.
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import base64
|
||
import json
|
||
import logging
|
||
import random
|
||
from pathlib import Path
|
||
|
||
from model_boss import InferenceClient
|
||
|
||
from engine.prompts import get_generation_size, get_negative, get_variant_modifier
|
||
from engine.registry import SpriteRegistry
|
||
|
||
logger = logging.getLogger(__name__)

# Max concurrent generation requests. model-boss queues them itself; this only
# controls how many results we are awaiting simultaneously.
MAX_CONCURRENT = 8


class SpriteGenerator:
    """Generate sprite variants through a model-boss ``InferenceClient``.

    Each request is submitted with ``InferenceClient.generate_image``; the
    returned image is written to ``raw_dir`` and recorded in the sprite
    registry as a variant.
    """

    def __init__(self, config: dict, registry: SpriteRegistry, raw_dir: Path):
        """Store configuration and create the inference client.

        Args:
            config: Must contain ``"model"`` and ``"defaults"``; ``defaults``
                may provide ``"steps"`` and ``"guidance_scale"``.
            registry: Registry used to look up sprites and record variants.
            raw_dir: Directory for raw images; created if it does not exist.
        """
        self.model = config["model"]
        self.defaults = config["defaults"]
        self.registry = registry
        self.raw_dir = raw_dir
        self.raw_dir.mkdir(parents=True, exist_ok=True)
        self._client = InferenceClient(
            client_id="sprite-generator",
            default_priority="normal",
            timeout=600.0,
        )
        self._semaphore = asyncio.Semaphore(MAX_CONCURRENT)

    @staticmethod
    def _extract_image_payload(result):
        """Return the first image payload found in ``result``, or None.

        Two result shapes are recognised:

        * ``{"data": [{"b64_json": "..."}], ...}``
        * ``{"images": ["..."], ...}``

        Anything that is not a dict yields None instead of raising.
        """
        if not isinstance(result, dict):
            return None

        data_list = result.get("data")
        if (
            isinstance(data_list, (list, tuple))
            and data_list
            and isinstance(data_list[0], dict)
        ):
            payload = data_list[0].get("b64_json")
            if payload:
                return payload

        images = result.get("images")
        if isinstance(images, (list, tuple)) and images:
            return images[0]
        return None

    async def generate_one(
        self,
        sprite_id: str,
        prompt: str,
        negative: str,
        width: int,
        height: int,
        seed: int,
        prompt_modifier: str = "",
        priority: str = "normal",
        dimension_id: int | None = None,
    ) -> int | None:
        """Generate a single sprite image via model-boss.

        Submits the request, saves the raw image to ``raw_dir`` and records
        the variant in the registry.

        Args:
            sprite_id: Sprite identifier; ``/`` is replaced by ``_`` when
                building the output filename.
            prompt: Base prompt.
            negative: Negative prompt.
            width: Requested image width.
            height: Requested image height.
            seed: Generation seed, also stored with the variant.
            prompt_modifier: Optional text appended to ``prompt``.
            priority: Priority forwarded to the client.
            dimension_id: Optional dimension id stored with the variant.

        Returns:
            The new variant id, or None when the request fails, the result
            carries no image, or the image payload is not valid base64.
        """
        full_prompt = ", ".join(p for p in [prompt, prompt_modifier] if p)

        # The semaphore only bounds in-flight requests; disk and registry
        # work below happens after the slot has been released.
        async with self._semaphore:
            try:
                result = await self._client.generate_image(
                    model=self.model,
                    prompt=full_prompt,
                    negative_prompt=negative,
                    width=width,
                    height=height,
                    steps=self.defaults.get("steps", 25),
                    guidance_scale=self.defaults.get("guidance_scale", 7.5),
                    seed=seed,
                    priority=priority,
                    keep_alive=300,
                )
            except Exception as exc:
                logger.warning("Generation failed for %s: %s", sprite_id, exc)
                print(f"[error] {sprite_id}: {exc}")
                return None

        image_b64 = self._extract_image_payload(result)
        if not image_b64:
            keys = list(result) if isinstance(result, dict) else []
            logger.warning(
                "No image data in result for %s (keys: %s)", sprite_id, keys
            )
            print(f"[error] {sprite_id}: no image data in result (keys: {keys})")
            return None

        if isinstance(image_b64, str):
            try:
                image_bytes = base64.b64decode(image_b64)
            except ValueError as exc:
                # binascii.Error is a ValueError subclass; a corrupt payload
                # is reported as a failed generation instead of raising.
                logger.warning(
                    "Invalid base64 image data for %s: %s", sprite_id, exc
                )
                print(f"[error] {sprite_id}: invalid base64 image data: {exc}")
                return None
        else:
            # Non-string payloads are written out unchanged.
            image_bytes = image_b64

        duration = result.get("durationMs") or result.get("duration_ms")

        # Create the variant first: its id is part of the raw filename.
        variant_id = self.registry.add_variant(
            sprite_id=sprite_id,
            seed=seed,
            dimension_id=dimension_id,
            prompt_modifier=prompt_modifier,
            job_id=None,
        )

        safe_name = sprite_id.replace("/", "_")
        raw_path = self.raw_dir / f"{safe_name}_{variant_id}.png"
        raw_path.write_bytes(image_bytes)

        self.registry.update_variant_status(
            variant_id,
            "completed",
            raw_path=str(raw_path),
            generation_ms=duration,
        )

        return variant_id

    async def generate_batch(
        self,
        sprite_ids: list[str],
        variants_per: int = 4,
        priority: str = "normal",
    ) -> int:
        """Generate variants for a list of sprites concurrently.

        One task is created per (sprite, variant) pair; the number of
        requests awaited at once is bounded by the instance semaphore.
        Sprite ids the registry does not know are skipped.

        Args:
            sprite_ids: Sprites to generate variants for.
            variants_per: Number of variants requested per sprite.
            priority: Priority forwarded to every request.

        Returns:
            Total number of variants generated successfully.
        """
        run_id = self.registry.start_run(variants_per=variants_per)
        tasks: list[asyncio.Task] = []
        total_planned = 0

        for sprite_id in sprite_ids:
            sprite = self.registry.get_sprite(sprite_id)
            if not sprite:
                logger.warning("Skipping unknown sprite %s", sprite_id)
                continue

            prompt = sprite["prompt"]
            negative = sprite["negative_prompt"] or get_negative(sprite["category"])
            gen_w, gen_h = get_generation_size(sprite["category"])
            # NOTE(review): .get() assumes get_sprite returns a dict-like
            # mapping -- confirm against SpriteRegistry.
            width = sprite.get("gen_width") or gen_w
            height = sprite.get("gen_height") or gen_h

            for i in range(variants_per):
                seed = random.randint(0, 2**32 - 1)
                modifier = get_variant_modifier(i)
                total_planned += 1

                task = asyncio.create_task(
                    self._generate_and_report(
                        sprite_id=sprite_id,
                        prompt=prompt,
                        negative=negative,
                        width=width,
                        height=height,
                        seed=seed,
                        prompt_modifier=modifier,
                        priority=priority,
                        index=total_planned,
                    )
                )
                tasks.append(task)

        print(f"Queued {total_planned} generation requests ({len(sprite_ids)} sprites × {variants_per} variants)")

        # Record the planned total on the run row.
        with self.registry.conn:
            self.registry.conn.execute(
                "UPDATE generation_runs SET total_jobs=? WHERE id=?",
                (total_planned, run_id),
            )

        # return_exceptions=True keeps one failing task from aborting the
        # batch; the exceptions are logged below rather than dropped.
        results = await asyncio.gather(*tasks, return_exceptions=True)

        completed = 0
        for outcome in results:
            if isinstance(outcome, BaseException):
                logger.warning("Generation task raised: %r", outcome)
            elif isinstance(outcome, int):
                completed += 1

        failed = total_planned - completed
        self.registry.update_run(
            run_id, completed_delta=completed, failed_delta=failed, finished=True
        )

        print(f"\nCompleted: {completed}/{total_planned} ({failed} failed)")
        return completed

    async def _generate_and_report(
        self,
        sprite_id: str,
        prompt: str,
        negative: str,
        width: int,
        height: int,
        seed: int,
        prompt_modifier: str,
        priority: str,
        index: int,
    ) -> int | None:
        """Generate one variant and print a progress line on success.

        ``index`` is the 1-based position of the request within its batch
        and is used only in the progress output.
        """
        variant_id = await self.generate_one(
            sprite_id=sprite_id,
            prompt=prompt,
            negative=negative,
            width=width,
            height=height,
            seed=seed,
            prompt_modifier=prompt_modifier,
            priority=priority,
        )
        if variant_id is not None:
            print(f"[{index}] Done {sprite_id} (seed={seed}) → variant {variant_id}")
        return variant_id

    async def close(self) -> None:
        """Clean up client resources."""
        await self._client.close()
|