mirror of
https://github.com/github/awesome-copilot.git
synced 2026-02-20 02:15:12 +00:00
fix(generate_image): improve input image handling and validate output filename extensions
This commit is contained in:
@@ -38,10 +38,10 @@ def parse_args():
|
||||
help="Output resolution: 1K, 2K, or 4K.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-image",
|
||||
action="append",
|
||||
default=[],
|
||||
help=f"Optional input image path (repeatable, max {MAX_INPUT_IMAGES}).",
|
||||
"--input-image",
|
||||
action="append",
|
||||
default=[],
|
||||
help=f"Optional input image path (repeatable, max {MAX_INPUT_IMAGES}).",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -71,37 +71,44 @@ def build_message_content(prompt: str, input_images):
|
||||
content.append({"type": "image_url", "image_url": {"url": data_url}})
|
||||
return content
|
||||
|
||||
def parse_data_url(data_url: str):
|
||||
if not data_url.startswith("data:") or ";base64," not in data_url:
|
||||
def parse_data_url(data_url: str):
|
||||
if not data_url.startswith("data:") or ";base64," not in data_url:
|
||||
raise ValueError("Image URL is not a base64 data URL.")
|
||||
header, encoded = data_url.split(",", 1)
|
||||
mime = header[5:].split(";", 1)[0]
|
||||
try:
|
||||
header, encoded = data_url.split(",", 1)
|
||||
mime = header[5:].split(";", 1)[0]
|
||||
try:
|
||||
raw = base64.b64decode(encoded)
|
||||
except Exception as e:
|
||||
except Exception as e:
|
||||
raise SystemExit(f"Failed to decode base64 image payload: {e}")
|
||||
return mime, raw
|
||||
return mime, raw
|
||||
|
||||
|
||||
def resolve_output_paths(filename: str, image_count: int, mime: str):
|
||||
output_path = Path(filename)
|
||||
suffix = output_path.suffix
|
||||
if not suffix:
|
||||
suffix = MIME_TO_EXT.get(mime, ".png")
|
||||
output_path = output_path.with_suffix(suffix)
|
||||
output_path = Path(filename)
|
||||
suffix = output_path.suffix
|
||||
|
||||
# Create parent directory if it doesn't exist (for absolute paths)
|
||||
if output_path.parent and str(output_path.parent) != '.':
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# Validate/correct suffix matches MIME type
|
||||
expected_suffix = MIME_TO_EXT.get(mime, ".png")
|
||||
if suffix and suffix.lower() != expected_suffix.lower():
|
||||
print(f"Warning: filename extension '{suffix}' doesn't match returned MIME type '{mime}'. Using '{expected_suffix}' instead.")
|
||||
suffix = expected_suffix
|
||||
elif not suffix:
|
||||
suffix = expected_suffix
|
||||
|
||||
if image_count == 1:
|
||||
return [output_path]
|
||||
output_path = output_path.with_suffix(suffix)
|
||||
|
||||
paths = []
|
||||
for index in range(image_count):
|
||||
numbered = output_path.with_name(f"{output_path.stem}-{index + 1}{suffix}")
|
||||
paths.append(numbered)
|
||||
return paths
|
||||
# Create parent directory if it doesn't exist (for paths with parent directories, absolute or relative)
|
||||
if output_path.parent and str(output_path.parent) != '.':
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if image_count == 1:
|
||||
return [output_path]
|
||||
|
||||
paths = []
|
||||
for index in range(image_count):
|
||||
numbered = output_path.with_name(f"{output_path.stem}-{index + 1}{suffix}")
|
||||
paths.append(numbered)
|
||||
return paths
|
||||
|
||||
|
||||
def extract_image_url(image):
|
||||
|
||||
Reference in New Issue
Block a user