magicciv/tools/sprite-generation/engine/processor.py

197 lines
7.6 KiB
Python

"""Image post-processing for generated sprites — raw model output to game-ready assets."""
from __future__ import annotations
from pathlib import Path
import numpy as np
from PIL import Image, ImageDraw, ImageFilter
_CATEGORY_PROCESSORS = {
"terrain": "_process_terrain",
"biome_grid": "_process_biome_grid",
"edges": "_process_edges",
"units": "_process_units",
"buildings": "_process_buildings",
"resources": "_process_resources",
"improvements": "_process_improvements",
"spells": "_process_spells",
"ui": "_process_ui",
}
class SpriteProcessor:
    """Turn raw model-output images into game-ready sprite assets.

    ``process()`` dispatches on a category name (see ``_CATEGORY_PROCESSORS``)
    to a ``_process_<category>`` method:

    * terrain / biome_grid / edges: center crop, resize to 384x332, hex mask.
    * units / buildings / resources / improvements / ui: remove the green
      chroma-key background, center crop, resize to a square.
    * spells: center crop and resize only (no chroma key).
    """

    def __init__(self, hex_mask_path: Path | None = None) -> None:
        """Load the hex mask from *hex_mask_path*, or generate one.

        Args:
            hex_mask_path: Optional path to a grayscale mask image
                (white = inside the hex). If it is falsy or the file does not
                exist, a 384x332 flat-top hex mask is generated instead.
        """
        self.hex_mask: Image.Image | None = None
        if hex_mask_path and hex_mask_path.exists():
            # convert() loads the pixels into a new image, so the source file
            # can be closed right away; the context manager guarantees that.
            with Image.open(hex_mask_path) as mask_src:
                self.hex_mask = mask_src.convert("L")
        else:
            self.hex_mask = self.generate_hex_mask()

    def process(self, raw_path: Path, category: str, output_path: Path) -> bool:
        """Process one raw image for *category* and save it to *output_path*.

        Args:
            raw_path: Raw generated image to read.
            category: Sprite category; must be a key of ``_CATEGORY_PROCESSORS``.
            output_path: Destination file; missing parent directories are created.

        Returns:
            ``True`` when the sprite was written; ``False`` when the category is
            unknown or the raw image could not be opened/decoded. Errors raised
            while processing or saving are not caught and propagate.
        """
        method_name = _CATEGORY_PROCESSORS.get(category)
        if method_name is None:
            print(f"Unknown category '{category}', skipping {raw_path.name}")
            return False
        try:
            # Decode inside ``with`` so the file handle is always released,
            # including for multi-frame formats Pillow would otherwise keep open.
            with Image.open(raw_path) as src:
                img = src.convert("RGBA")
        except Exception as exc:  # batch boundary: report and skip this file
            print(f"Failed to open {raw_path}: {exc}")
            return False
        result: Image.Image = getattr(self, method_name)(img)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        result.save(output_path)
        return True

    # -- category processors ---------------------------------------------------
    def _process_terrain(self, img: Image.Image) -> Image.Image:
        """Center crop to square, resize to 384x332, apply hex mask.

        NOTE(review): resizing a square crop to 384x332 stretches the content
        horizontally by ~16% (384/332); confirm this distortion is intended.
        """
        square = self._center_crop_square(img)
        return self._apply_hex_mask(square.resize((384, 332), Image.LANCZOS))

    def _process_biome_grid(self, img: Image.Image) -> Image.Image:
        """Same as terrain — 384x332 hex-masked."""
        return self._process_terrain(img)

    def _process_edges(self, img: Image.Image) -> Image.Image:
        """384x332 with hex mask (same pipeline as terrain)."""
        return self._process_terrain(img)

    def _chroma_keyed_sprite(self, img: Image.Image, size: int) -> Image.Image:
        """Remove the green chroma key, center crop, and resize to size x size."""
        keyed = self._remove_chroma_key(img)
        return self._center_crop_square(keyed).resize((size, size), Image.LANCZOS)

    def _process_units(self, img: Image.Image) -> Image.Image:
        """Remove green chroma key, center crop, resize to 256x256."""
        return self._chroma_keyed_sprite(img, 256)

    def _process_buildings(self, img: Image.Image) -> Image.Image:
        """Remove green chroma key, center crop, resize to 128x128."""
        return self._chroma_keyed_sprite(img, 128)

    def _process_resources(self, img: Image.Image) -> Image.Image:
        """Remove green chroma key, center crop, resize to 64x64."""
        return self._chroma_keyed_sprite(img, 64)

    def _process_improvements(self, img: Image.Image) -> Image.Image:
        """Remove green chroma key, center crop, resize to 64x64."""
        return self._chroma_keyed_sprite(img, 64)

    def _process_spells(self, img: Image.Image) -> Image.Image:
        """Center crop to square, resize to 128x128. No chroma key — spells use black bg."""
        return self._center_crop_square(img).resize((128, 128), Image.LANCZOS)

    def _process_ui(self, img: Image.Image) -> Image.Image:
        """Remove green chroma key, center crop, resize to 64x64."""
        return self._chroma_keyed_sprite(img, 64)

    # -- hex mask --------------------------------------------------------------
    def _apply_hex_mask(self, img: Image.Image) -> Image.Image:
        """Restrict the image's alpha channel to the hex mask.

        The mask is resized to the image size when they differ. The resulting
        alpha equals the existing alpha where the mask is 255, 0 where the mask
        is 0, and is interpolated in between. Returns *img* unchanged when no
        mask is set.
        """
        mask = self.hex_mask
        if mask is None:
            return img
        if mask.size != img.size:
            mask = mask.resize(img.size, Image.LANCZOS)
        if img.mode != "RGBA":
            img = img.convert("RGBA")
        r, g, b, a = img.split()
        # Blend existing alpha toward 0 wherever the mask is below 255.
        combined_alpha = Image.composite(a, Image.new("L", img.size, 0), mask)
        return Image.merge("RGBA", (r, g, b, combined_alpha))

    # -- chroma key removal ----------------------------------------------------
    @staticmethod
    def _remove_chroma_key(
        img: Image.Image,
        green_dominance: int = 30,
        min_green: int = 80,
        edge_feather: int = 3,
        alpha_boost: float = 1.5,
    ) -> Image.Image:
        """Remove green chroma key background, replacing with transparency.

        A pixel is treated as chroma key when its green channel is bright AND
        dominates R and B. The two thresholds work together:

        - green_dominance: G - max(R, B) must be strictly greater than this
          (filters out dark greens, olive, etc. that appear in clothing or
          foliage subjects).
        - min_green: G must be strictly greater than this (filters out dark
          shadows that happen to be greenish).

        When edge_feather > 0, the binary alpha is Gaussian-blurred with that
        radius and multiplied by alpha_boost (clipped to 255) to soften edges
        while keeping the interior opaque. The result's alpha is the minimum of
        the image's existing alpha and this computed alpha.

        NOTE(review): the blur spreads opacity outward, so chroma pixels next
        to the subject can end up partially opaque (a green fringe). Consider
        forcing alpha to 0 on is_chroma pixels if fringing is visible.
        """
        if img.mode != "RGBA":
            img = img.convert("RGBA")
        arr = np.array(img, dtype=np.float32)
        r, g, b = arr[:, :, 0], arr[:, :, 1], arr[:, :, 2]

        # Chroma key detection: bright green that dominates the other channels.
        green_score = g - np.maximum(r, b)
        is_chroma = (green_score > green_dominance) & (g > min_green)

        # Alpha mask: 0 for chroma pixels, 255 for subject. A 2-D uint8 array
        # maps to mode "L" automatically (fromarray's mode= is deprecated).
        alpha = np.where(is_chroma, 0, 255).astype(np.uint8)
        alpha_img = Image.fromarray(alpha)

        if edge_feather > 0:
            # Feather edges for smooth anti-aliasing...
            alpha_img = alpha_img.filter(ImageFilter.GaussianBlur(radius=edge_feather))
            # ...then boost so the interior stays hard while edges stay soft.
            alpha_arr = np.array(alpha_img, dtype=np.float32)
            alpha_arr = np.clip(alpha_arr * alpha_boost, 0, 255).astype(np.uint8)
            alpha_img = Image.fromarray(alpha_arr)

        # Never make a pixel more opaque than it already was.
        result = img.copy()
        orig_alpha = np.array(img.getchannel("A"))
        new_alpha = np.minimum(orig_alpha, np.array(alpha_img))
        result.putalpha(Image.fromarray(new_alpha))
        return result

    # -- utilities -------------------------------------------------------------
    @staticmethod
    def _center_crop_square(img: Image.Image) -> Image.Image:
        """Crop the largest centered square out of *img*."""
        w, h = img.size
        side = min(w, h)
        left = (w - side) // 2
        top = (h - side) // 2
        return img.crop((left, top, left + side, top + side))

    @staticmethod
    def generate_hex_mask(width: int = 384, height: int = 332) -> Image.Image:
        """Generate a flat-top hexagon mask programmatically.

        Flat-top hex vertices for a hex of width W and height H,
        centered at (W/2, H/2):
        (W/4, 0), (3*W/4, 0), (W, H/2), (3*W/4, H), (W/4, H), (0, H/2)

        Returns a grayscale ("L") image where 255 = inside, 0 = outside.
        """
        mask = Image.new("L", (width, height), 0)
        draw = ImageDraw.Draw(mask)
        w4 = width / 4
        h2 = height / 2
        vertices = [
            (w4, 0),
            (3 * w4, 0),
            (width, h2),
            (3 * w4, height),
            (w4, height),
            (0, h2),
        ]
        draw.polygon(vertices, fill=255)
        return mask