Skip to content

Commit

Permalink
minor adjustment to image testing to support png (#152)
Browse files Browse the repository at this point in the history
  • Loading branch information
ludomitch authored Dec 16, 2024
1 parent dc68575 commit 4132491
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 7 deletions.
1 change: 1 addition & 0 deletions .mailmap
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ James Braza <james@futurehouse.org> <jamesbraza@gmail.com>
Michael Skarlinski <mskarlinski@futurehouse.org> mskarlin <12701035+mskarlin@users.noreply.github.com>
Ryan-Rhys Griffiths <ryan@futurehouse.org> <ryangriff123@gmail.com>
Siddharth Narayanan <sid@futurehouse.org> <sidnarayanan@users.noreply.github.com>
Ludovico Mitchener <ludo@futurehouse.org> ludomitch
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ repos:
- id: codespell
additional_dependencies: [".[toml]"]
exclude_types: [jupyter]
exclude: '.*\.b64$'
- repo: https://github.com/pappasam/toml-sort
rev: v0.24.2
hooks:
Expand Down
4 changes: 2 additions & 2 deletions src/aviary/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from pydantic import BaseModel, Field, field_validator, model_validator

from aviary.utils import check_if_valid_base64, encode_image_to_base64
from aviary.utils import encode_image_to_base64, validate_base64_image

if TYPE_CHECKING:
from logging import LogRecord
Expand Down Expand Up @@ -147,7 +147,7 @@ def create_message(
{
"type": "image_url",
"image_url": {
"url": check_if_valid_base64(image)
"url": validate_base64_image(image)
# If image is a string, assume it's already a base64 encoded image
if isinstance(image, str)
else encode_image_to_base64(image)
Expand Down
8 changes: 5 additions & 3 deletions src/aviary/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,12 @@ def encode_image_to_base64(img: "np.ndarray") -> str:
)


def check_if_valid_base64(image: str) -> str:
"""Check if the input string is a valid base64 encoded image."""
def validate_base64_image(image: str) -> str:
"""Validate if the input string is a valid base64 encoded image and if it is, return the image."""
try:
base64.b64decode(image)
# Support for inclusion of the data:image/ url prefix
test_image = image.split(",")[1] if image.startswith("data:image/") else image
base64.b64decode(test_image)
except Exception as err:
raise ValueError("Invalid base64 encoded image") from err
return image
Expand Down
1 change: 1 addition & 0 deletions tests/fixtures/test_images/sample_png_image.b64

Large diffs are not rendered by default.

11 changes: 9 additions & 2 deletions tests/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def test_dump(self, message: Message, expected: dict) -> None:
(
[
np.zeros((32, 32, 3), dtype=np.uint8), # red square
load_base64_image("sample_image1.b64"),
load_base64_image("sample_jpeg_image.b64"),
],
"What color are these squares? List each color.",
None,
Expand All @@ -151,7 +151,14 @@ def test_dump(self, message: Message, expected: dict) -> None:
),
# Case 4: A string should be converted to a base64 encoded image
(
load_base64_image("sample_image1.b64"),
load_base64_image("sample_jpeg_image.b64"),
"What color is this square?",
None,
2, # 1 image + 1 text
),
# Case 5: A PNG image should be converted to a base64 encoded image
(
load_base64_image("sample_png_image.b64"),
"What color is this square?",
None,
2, # 1 image + 1 text
Expand Down

0 comments on commit 4132491

Please sign in to comment.