fix(generate_image): improve input image handling and validate output filename extensions

This commit is contained in:
nblog
2026-02-09 17:24:24 +08:00
parent 5efb7329a3
commit b8bbc75db2

View File

@@ -38,10 +38,10 @@ def parse_args():
help="Output resolution: 1K, 2K, or 4K.", help="Output resolution: 1K, 2K, or 4K.",
) )
parser.add_argument( parser.add_argument(
"--input-image", "--input-image",
action="append", action="append",
default=[], default=[],
help=f"Optional input image path (repeatable, max {MAX_INPUT_IMAGES}).", help=f"Optional input image path (repeatable, max {MAX_INPUT_IMAGES}).",
) )
return parser.parse_args() return parser.parse_args()
@@ -71,37 +71,44 @@ def build_message_content(prompt: str, input_images):
content.append({"type": "image_url", "image_url": {"url": data_url}}) content.append({"type": "image_url", "image_url": {"url": data_url}})
return content return content
def parse_data_url(data_url: str): def parse_data_url(data_url: str):
if not data_url.startswith("data:") or ";base64," not in data_url: if not data_url.startswith("data:") or ";base64," not in data_url:
raise ValueError("Image URL is not a base64 data URL.") raise ValueError("Image URL is not a base64 data URL.")
header, encoded = data_url.split(",", 1) header, encoded = data_url.split(",", 1)
mime = header[5:].split(";", 1)[0] mime = header[5:].split(";", 1)[0]
try: try:
raw = base64.b64decode(encoded) raw = base64.b64decode(encoded)
except Exception as e: except Exception as e:
raise SystemExit(f"Failed to decode base64 image payload: {e}") raise SystemExit(f"Failed to decode base64 image payload: {e}")
return mime, raw return mime, raw
def resolve_output_paths(filename: str, image_count: int, mime: str): def resolve_output_paths(filename: str, image_count: int, mime: str):
output_path = Path(filename) output_path = Path(filename)
suffix = output_path.suffix suffix = output_path.suffix
if not suffix:
suffix = MIME_TO_EXT.get(mime, ".png")
output_path = output_path.with_suffix(suffix)
# Create parent directory if it doesn't exist (for absolute paths) # Validate/correct suffix matches MIME type
if output_path.parent and str(output_path.parent) != '.': expected_suffix = MIME_TO_EXT.get(mime, ".png")
output_path.parent.mkdir(parents=True, exist_ok=True) if suffix and suffix.lower() != expected_suffix.lower():
print(f"Warning: filename extension '{suffix}' doesn't match returned MIME type '{mime}'. Using '{expected_suffix}' instead.")
suffix = expected_suffix
elif not suffix:
suffix = expected_suffix
if image_count == 1: output_path = output_path.with_suffix(suffix)
return [output_path]
paths = [] # Create parent directory if it doesn't exist (for paths with parent directories, absolute or relative)
for index in range(image_count): if output_path.parent and str(output_path.parent) != '.':
numbered = output_path.with_name(f"{output_path.stem}-{index + 1}{suffix}") output_path.parent.mkdir(parents=True, exist_ok=True)
paths.append(numbered)
return paths if image_count == 1:
return [output_path]
paths = []
for index in range(image_count):
numbered = output_path.with_name(f"{output_path.stem}-{index + 1}{suffix}")
paths.append(numbered)
return paths
def extract_image_url(image): def extract_image_url(image):