fix(generate_image): resolve each output path from its own image's MIME type and tidy input-image handling in the generate_image script

This commit is contained in:
nblog
2026-02-09 21:09:09 +08:00
parent b8bbc75db2
commit 79c34297fa

View File

@@ -15,6 +15,8 @@ import mimetypes
import os
from pathlib import Path
from openai import OpenAI
# Configuration
MAX_INPUT_IMAGES = 3
@@ -56,7 +58,7 @@ def require_api_key():
def encode_image_to_data_url(path: Path) -> str:
if not path.exists():
raise SystemExit(f"Input image not found: {path}")
mime, _ = mimetypes.guess_type(path.name)
mime, _ = mimetypes.guess_type(str(path))
if not mime:
mime = "image/png"
data = path.read_bytes()
@@ -64,16 +66,17 @@ def encode_image_to_data_url(path: Path) -> str:
return f"data:{mime};base64,{encoded}"
def build_message_content(prompt: str, input_images):
content = [{"type": "text", "text": prompt}]
def build_message_content(prompt: str, input_images: list[str]) -> list[dict]:
content: list[dict] = [{"type": "text", "text": prompt}]
for image_path in input_images:
data_url = encode_image_to_data_url(Path(image_path))
content.append({"type": "image_url", "image_url": {"url": data_url}})
return content
def parse_data_url(data_url: str):
def parse_data_url(data_url: str) -> tuple[str, bytes]:
if not data_url.startswith("data:") or ";base64," not in data_url:
raise ValueError("Image URL is not a base64 data URL.")
raise SystemExit("Image URL is not a base64 data URL.")
header, encoded = data_url.split(",", 1)
mime = header[5:].split(";", 1)[0]
try:
@@ -83,35 +86,27 @@ def parse_data_url(data_url: str):
return mime, raw
def resolve_output_paths(filename: str, image_count: int, mime: str):
output_path = Path(filename)
suffix = output_path.suffix
def resolve_output_path(filename: str, image_index: int, total_count: int, mime: str) -> Path:
output_path = Path(filename)
suffix = output_path.suffix
# Validate/correct suffix matches MIME type
expected_suffix = MIME_TO_EXT.get(mime, ".png")
if suffix and suffix.lower() != expected_suffix.lower():
print(f"Warning: filename extension '{suffix}' doesn't match returned MIME type '{mime}'. Using '{expected_suffix}' instead.")
suffix = expected_suffix
elif not suffix:
suffix = expected_suffix
# Validate/correct suffix matches MIME type
expected_suffix = MIME_TO_EXT.get(mime, ".png")
if suffix and suffix.lower() != expected_suffix.lower():
print(f"Warning: filename extension '{suffix}' doesn't match returned MIME type '{mime}'. Using '{expected_suffix}' instead.")
suffix = expected_suffix
elif not suffix:
suffix = expected_suffix
output_path = output_path.with_suffix(suffix)
# Single image: use original stem + corrected suffix
if total_count <= 1:
return output_path.with_suffix(suffix)
# Create parent directory if it doesn't exist (for paths with parent directories, absolute or relative)
if output_path.parent and str(output_path.parent) != '.':
output_path.parent.mkdir(parents=True, exist_ok=True)
if image_count == 1:
return [output_path]
paths = []
for index in range(image_count):
numbered = output_path.with_name(f"{output_path.stem}-{index + 1}{suffix}")
paths.append(numbered)
return paths
# Multiple images: append numbering
return output_path.with_name(f"{output_path.stem}-{image_index + 1}{suffix}")
def extract_image_url(image):
def extract_image_url(image: dict | object) -> str | None:
if isinstance(image, dict):
return image.get("image_url", {}).get("url") or image.get("url")
return None
@@ -123,7 +118,7 @@ def load_system_prompt():
template_path = script_dir / "assets" / "SYSTEM_TEMPLATE"
if template_path.exists():
content = template_path.read_text().strip()
content = template_path.read_text(encoding="utf-8").strip()
if content:
return content
return None
@@ -135,9 +130,8 @@ def main():
if len(args.input_image) > MAX_INPUT_IMAGES:
raise SystemExit(f"Too many input images: {len(args.input_image)} (max {MAX_INPUT_IMAGES}).")
image_size = args.resolution or "1K"
image_size = args.resolution
from openai import OpenAI
client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=require_api_key())
# Build messages with optional system prompt
@@ -173,19 +167,18 @@ def main():
if not images:
raise SystemExit("No images returned by the API.")
first_url = extract_image_url(images[0])
if not first_url:
raise SystemExit("Image payload missing image_url.url.")
first_mime, _ = parse_data_url(first_url)
output_paths = resolve_output_paths(args.filename, len(images), first_mime)
# Create output directory once before processing images
output_base_path = Path(args.filename)
if output_base_path.parent and str(output_base_path.parent) != '.':
output_base_path.parent.mkdir(parents=True, exist_ok=True)
saved_paths = []
for idx, image in enumerate(images):
image_url = extract_image_url(image)
if not image_url:
raise SystemExit("Image payload missing image_url.url.")
_, raw = parse_data_url(image_url)
output_path = output_paths[idx]
mime, raw = parse_data_url(image_url)
output_path = resolve_output_path(args.filename, idx, len(images), mime)
output_path.write_bytes(raw)
saved_paths.append(output_path.resolve())