fix(generate_image): improve input image handling and validate output filename extensions

This commit is contained in:
nblog
2026-02-09 17:24:24 +08:00
parent 5efb7329a3
commit b8bbc75db2

View File

@@ -38,10 +38,10 @@ def parse_args():
help="Output resolution: 1K, 2K, or 4K.",
)
parser.add_argument(
"--input-image",
action="append",
default=[],
help=f"Optional input image path (repeatable, max {MAX_INPUT_IMAGES}).",
"--input-image",
action="append",
default=[],
help=f"Optional input image path (repeatable, max {MAX_INPUT_IMAGES}).",
)
return parser.parse_args()
@@ -71,37 +71,44 @@ def build_message_content(prompt: str, input_images):
content.append({"type": "image_url", "image_url": {"url": data_url}})
return content
def parse_data_url(data_url: str):
if not data_url.startswith("data:") or ";base64," not in data_url:
def parse_data_url(data_url: str):
if not data_url.startswith("data:") or ";base64," not in data_url:
raise ValueError("Image URL is not a base64 data URL.")
header, encoded = data_url.split(",", 1)
mime = header[5:].split(";", 1)[0]
try:
header, encoded = data_url.split(",", 1)
mime = header[5:].split(";", 1)[0]
try:
raw = base64.b64decode(encoded)
except Exception as e:
except Exception as e:
raise SystemExit(f"Failed to decode base64 image payload: {e}")
return mime, raw
return mime, raw
def resolve_output_paths(filename: str, image_count: int, mime: str):
output_path = Path(filename)
suffix = output_path.suffix
if not suffix:
suffix = MIME_TO_EXT.get(mime, ".png")
output_path = output_path.with_suffix(suffix)
output_path = Path(filename)
suffix = output_path.suffix
# Create parent directory if it doesn't exist (for absolute paths)
if output_path.parent and str(output_path.parent) != '.':
output_path.parent.mkdir(parents=True, exist_ok=True)
# Validate/correct suffix matches MIME type
expected_suffix = MIME_TO_EXT.get(mime, ".png")
if suffix and suffix.lower() != expected_suffix.lower():
print(f"Warning: filename extension '{suffix}' doesn't match returned MIME type '{mime}'. Using '{expected_suffix}' instead.")
suffix = expected_suffix
elif not suffix:
suffix = expected_suffix
if image_count == 1:
return [output_path]
output_path = output_path.with_suffix(suffix)
paths = []
for index in range(image_count):
numbered = output_path.with_name(f"{output_path.stem}-{index + 1}{suffix}")
paths.append(numbered)
return paths
# Create parent directory if it doesn't exist (for paths with parent directories, absolute or relative)
if output_path.parent and str(output_path.parent) != '.':
output_path.parent.mkdir(parents=True, exist_ok=True)
if image_count == 1:
return [output_path]
paths = []
for index in range(image_count):
numbered = output_path.with_name(f"{output_path.stem}-{index + 1}{suffix}")
paths.append(numbered)
return paths
def extract_image_url(image):