mirror of
https://github.com/github/awesome-copilot.git
synced 2026-02-23 20:05:12 +00:00
fix(generate_image): improve input image handling and validate output filename extensions
This commit is contained in:
@@ -38,10 +38,10 @@ def parse_args():
|
|||||||
help="Output resolution: 1K, 2K, or 4K.",
|
help="Output resolution: 1K, 2K, or 4K.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--input-image",
|
"--input-image",
|
||||||
action="append",
|
action="append",
|
||||||
default=[],
|
default=[],
|
||||||
help=f"Optional input image path (repeatable, max {MAX_INPUT_IMAGES}).",
|
help=f"Optional input image path (repeatable, max {MAX_INPUT_IMAGES}).",
|
||||||
)
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
@@ -71,37 +71,44 @@ def build_message_content(prompt: str, input_images):
|
|||||||
content.append({"type": "image_url", "image_url": {"url": data_url}})
|
content.append({"type": "image_url", "image_url": {"url": data_url}})
|
||||||
return content
|
return content
|
||||||
|
|
||||||
def parse_data_url(data_url: str):
|
def parse_data_url(data_url: str):
|
||||||
if not data_url.startswith("data:") or ";base64," not in data_url:
|
if not data_url.startswith("data:") or ";base64," not in data_url:
|
||||||
raise ValueError("Image URL is not a base64 data URL.")
|
raise ValueError("Image URL is not a base64 data URL.")
|
||||||
header, encoded = data_url.split(",", 1)
|
header, encoded = data_url.split(",", 1)
|
||||||
mime = header[5:].split(";", 1)[0]
|
mime = header[5:].split(";", 1)[0]
|
||||||
try:
|
try:
|
||||||
raw = base64.b64decode(encoded)
|
raw = base64.b64decode(encoded)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise SystemExit(f"Failed to decode base64 image payload: {e}")
|
raise SystemExit(f"Failed to decode base64 image payload: {e}")
|
||||||
return mime, raw
|
return mime, raw
|
||||||
|
|
||||||
|
|
||||||
def resolve_output_paths(filename: str, image_count: int, mime: str):
|
def resolve_output_paths(filename: str, image_count: int, mime: str):
|
||||||
output_path = Path(filename)
|
output_path = Path(filename)
|
||||||
suffix = output_path.suffix
|
suffix = output_path.suffix
|
||||||
if not suffix:
|
|
||||||
suffix = MIME_TO_EXT.get(mime, ".png")
|
|
||||||
output_path = output_path.with_suffix(suffix)
|
|
||||||
|
|
||||||
# Create parent directory if it doesn't exist (for absolute paths)
|
# Validate/correct suffix matches MIME type
|
||||||
if output_path.parent and str(output_path.parent) != '.':
|
expected_suffix = MIME_TO_EXT.get(mime, ".png")
|
||||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
if suffix and suffix.lower() != expected_suffix.lower():
|
||||||
|
print(f"Warning: filename extension '{suffix}' doesn't match returned MIME type '{mime}'. Using '{expected_suffix}' instead.")
|
||||||
|
suffix = expected_suffix
|
||||||
|
elif not suffix:
|
||||||
|
suffix = expected_suffix
|
||||||
|
|
||||||
if image_count == 1:
|
output_path = output_path.with_suffix(suffix)
|
||||||
return [output_path]
|
|
||||||
|
|
||||||
paths = []
|
# Create parent directory if it doesn't exist (for paths with parent directories, absolute or relative)
|
||||||
for index in range(image_count):
|
if output_path.parent and str(output_path.parent) != '.':
|
||||||
numbered = output_path.with_name(f"{output_path.stem}-{index + 1}{suffix}")
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
paths.append(numbered)
|
|
||||||
return paths
|
if image_count == 1:
|
||||||
|
return [output_path]
|
||||||
|
|
||||||
|
paths = []
|
||||||
|
for index in range(image_count):
|
||||||
|
numbered = output_path.with_name(f"{output_path.stem}-{index + 1}{suffix}")
|
||||||
|
paths.append(numbered)
|
||||||
|
return paths
|
||||||
|
|
||||||
|
|
||||||
def extract_image_url(image):
|
def extract_image_url(image):
|
||||||
|
|||||||
Reference in New Issue
Block a user