# magicciv/tools/sprite-generation/engine/processor.py
# (file-viewer header as captured: 197 lines, 7.2 KiB, Python)

"""Image post-processing for generated sprites — raw model output to game-ready assets."""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Optional
from PIL import Image, ImageDraw
logger = logging.getLogger(__name__)
_CATEGORY_PROCESSORS = {
"terrain": "_process_terrain",
"biome_grid": "_process_biome_grid",
"edges": "_process_edges",
"units": "_process_units",
"buildings": "_process_buildings",
"resources": "_process_resources",
"improvements": "_process_improvements",
"spells": "_process_spells",
"ui": "_process_ui",
}
class SpriteProcessor:
    """Convert raw generated images into game-ready sprite assets.

    ``process`` looks up the handler for a sprite category in
    ``_CATEGORY_PROCESSORS`` and dispatches to the matching
    ``_process_<category>`` method:

    * hex-tile categories (terrain, biome grid, edges) are center-cropped,
      resized to 384x332 and clipped with a flat-top hex mask;
    * all other categories have their background removed with rembg and are
      resized to a square sprite (256, 128 or 64 pixels per side).
    """

    # Output size (width, height) of hex tiles; matches the default mask size.
    _HEX_TILE_SIZE = (384, 332)

    def __init__(self, hex_mask_path: Path | None = None) -> None:
        """Load the hex mask from disk, or generate the default one.

        Args:
            hex_mask_path: Optional path to a mask image (a 384x332 flat-top
                hex polygon). It is converted to grayscale ("L"). When the
                path is omitted or does not exist, a mask is generated with
                ``generate_hex_mask``.
        """
        self.hex_mask: Image.Image | None = None
        if hex_mask_path and hex_mask_path.exists():
            # convert() produces an independent in-memory image, so the
            # underlying file can be closed as soon as the block exits.
            with Image.open(hex_mask_path) as mask_file:
                self.hex_mask = mask_file.convert("L")
        else:
            if hex_mask_path:
                # An explicit path that cannot be found is most likely a
                # configuration mistake; keep the fallback but say so.
                logger.warning(
                    "Hex mask %s not found; using generated default mask",
                    hex_mask_path,
                )
            self.hex_mask = self.generate_hex_mask()
        # Lazy-loaded rembg session for background removal
        self._rembg_session: Optional[object] = None

    def process(self, raw_path: Path, category: str, output_path: Path) -> bool:
        """Open a raw image, apply its category processing and save the result.

        Args:
            raw_path: Path of the raw image to read.
            category: Sprite category; must be a key of
                ``_CATEGORY_PROCESSORS``.
            output_path: Destination file. Missing parent directories are
                created.

        Returns:
            True when the processed image was saved; False when the category
            is unknown or the raw image could not be opened (both are logged).

        Raises:
            Exception: Errors raised while processing or saving the image are
                not caught here and propagate to the caller.
        """
        method_name = _CATEGORY_PROCESSORS.get(category)
        if method_name is None:
            logger.warning(
                "Unknown category '%s', skipping %s", category, raw_path.name
            )
            return False

        try:
            with Image.open(raw_path) as raw:
                img = raw.convert("RGBA")
        except Exception as exc:
            logger.error("Failed to open %s: %s", raw_path, exc)
            return False

        result: Image.Image = getattr(self, method_name)(img)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        result.save(output_path)
        return True

    # -- category processors ---------------------------------------------------

    def _process_terrain(self, img: Image.Image) -> Image.Image:
        """Center crop to square, resize to 384x332, apply hex mask.

        Note: the square crop is scaled to a non-square target, so the image
        is stretched horizontally relative to its height.
        """
        tile = self._center_crop_square(img).resize(
            self._HEX_TILE_SIZE, Image.LANCZOS
        )
        return self._apply_hex_mask(tile)

    def _process_biome_grid(self, img: Image.Image) -> Image.Image:
        """Same as terrain: 384x332, hex-masked."""
        return self._process_terrain(img)

    def _process_edges(self, img: Image.Image) -> Image.Image:
        """Same as terrain: 384x332, hex-masked."""
        return self._process_terrain(img)

    def _make_square_sprite(self, img: Image.Image, side: int) -> Image.Image:
        """Remove background, center crop, resize to ``side`` x ``side``."""
        cutout = self._remove_background(img)
        return self._center_crop_square(cutout).resize(
            (side, side), Image.LANCZOS
        )

    def _process_units(self, img: Image.Image) -> Image.Image:
        """Remove background, center crop, resize to 256x256."""
        return self._make_square_sprite(img, 256)

    def _process_buildings(self, img: Image.Image) -> Image.Image:
        """Remove background, center crop, resize to 128x128."""
        return self._make_square_sprite(img, 128)

    def _process_resources(self, img: Image.Image) -> Image.Image:
        """Remove background, center crop, resize to 64x64."""
        return self._make_square_sprite(img, 64)

    def _process_improvements(self, img: Image.Image) -> Image.Image:
        """Remove background, center crop, resize to 64x64."""
        return self._make_square_sprite(img, 64)

    def _process_spells(self, img: Image.Image) -> Image.Image:
        """Remove background, center crop, resize to 128x128."""
        return self._make_square_sprite(img, 128)

    def _process_ui(self, img: Image.Image) -> Image.Image:
        """Remove background, center crop, resize to 64x64."""
        return self._make_square_sprite(img, 64)

    # -- hex mask --------------------------------------------------------------

    def _apply_hex_mask(self, img: Image.Image) -> Image.Image:
        """Clip the image's alpha channel with the hex mask.

        The mask is resized to the image size when they differ, and the image
        is converted to RGBA when needed. The image is returned unchanged if
        no mask is set.
        """
        mask = self.hex_mask
        if mask is None:
            return img
        if mask.size != img.size:
            mask = mask.resize(img.size, Image.LANCZOS)
        if img.mode != "RGBA":
            img = img.convert("RGBA")
        r, g, b, a = img.split()
        # Combine existing alpha with hex mask: keep the image's own alpha
        # where the mask is white, fully transparent where it is black.
        transparent = Image.new("L", img.size, 0)
        combined_alpha = Image.composite(a, transparent, mask)
        return Image.merge("RGBA", (r, g, b, combined_alpha))

    # -- background removal (rembg / U2Net) ------------------------------------

    def _get_rembg_session(self) -> object:
        """Lazy-load the rembg session on first use and cache it.

        Uses U2Net with the CPU provider to avoid CUDA dtype conflicts: the
        diffusion pipeline uses fp16 but rembg/onnxruntime expects fp32.
        CPU is fast enough for the small images we process (~1024x1024).
        """
        if self._rembg_session is not None:
            return self._rembg_session

        # Imported lazily so rembg is only required when a category that
        # needs background removal is actually processed.
        from rembg import new_session

        self._rembg_session = new_session(
            "u2net",
            providers=["CPUExecutionProvider"],
        )
        logger.info("rembg session initialized (u2net, CPU)")
        return self._rembg_session

    def _remove_background(self, img: Image.Image) -> Image.Image:
        """Remove the background using U2Net neural segmentation via rembg.

        Unlike chroma keying, this works regardless of background or subject
        color: it segments the salient foreground object from everything else.
        Alpha matting is enabled for cleaner edges on small game sprites.

        The image is converted to RGB first, so any existing alpha channel is
        discarded before segmentation.
        """
        from rembg import remove

        if img.mode != "RGB":
            img = img.convert("RGB")
        return remove(
            img,
            session=self._get_rembg_session(),
            alpha_matting=True,
            alpha_matting_foreground_threshold=240,
            alpha_matting_background_threshold=10,
        )

    # -- utilities -------------------------------------------------------------

    @staticmethod
    def _center_crop_square(img: Image.Image) -> Image.Image:
        """Crop the center of the image to a square.

        The square's side is the shorter image dimension; any odd leftover
        pixel ends up on the right/bottom side of the crop.
        """
        w, h = img.size
        side = min(w, h)
        left = (w - side) // 2
        top = (h - side) // 2
        return img.crop((left, top, left + side, top + side))

    @staticmethod
    def generate_hex_mask(width: int = 384, height: int = 332) -> Image.Image:
        """Generate a flat-top hexagon mask programmatically.

        Flat-top hex vertices for a hex of width W and height H,
        centered at (W/2, H/2):
        (W/4, 0), (3*W/4, 0), (W, H/2), (3*W/4, H), (W/4, H), (0, H/2)

        Returns a grayscale Image where white=inside, black=outside.
        """
        mask = Image.new("L", (width, height), 0)
        draw = ImageDraw.Draw(mask)
        w4 = width / 4
        h2 = height / 2
        vertices = [
            (w4, 0),
            (3 * w4, 0),
            (width, h2),
            (3 * w4, height),
            (w4, height),
            (0, h2),
        ]
        draw.polygon(vertices, fill=255)
        return mask