Skip to content

Commit

Permalink
Update tests to reflect #13
Browse files Browse the repository at this point in the history
  • Loading branch information
amyeroberts committed Apr 14, 2024
1 parent 0e16d4a commit 107693a
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 50 deletions.
6 changes: 3 additions & 3 deletions src/transformers/models/idefics2/image_processing_idefics2.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def __init__(
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: bool = True,
do_image_splitting = False,
do_image_splitting=False,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand Down Expand Up @@ -374,7 +374,7 @@ def split_image(
Split an image into 4 equal sub-images, and then concatenate that sequence with the original image.
That means that a single image becomes a sequence of 5 images.
This is a "trick" to spend more compute on each image with no changes in the vision encoder.
Args:
image (`np.ndarray`):
Images to split.
Expand All @@ -390,7 +390,7 @@ def split_image(
self._crop(image, mid_width, 0, width, mid_height, input_data_format),
self._crop(image, 0, mid_height, mid_width, height, input_data_format),
self._crop(image, mid_width, mid_height, width, height, input_data_format),
image
image,
]

def preprocess(
Expand Down
48 changes: 5 additions & 43 deletions src/transformers/models/idefics2/processing_idefics2.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,49 +41,6 @@ def is_image_or_image_url(elem):
return is_url(elem) or is_valid_image(elem)


def build_string_from_input(prompt, image_seq_len, bos_token, image_token, fake_image_token, do_image_splitting=False):
    """
    Builds a string from the input prompt and image tokens.

    For example, for the call:

    build_string_from_input(
        prompt=["Initial str", img1, img2, "mid str", img3],
        image_seq_len=2,
        bos_token="<s>",
        image_token="<im>",
        fake_image_token="<fake>"
    )

    The output will be:

    "<s>Initial str<fake><im><im><fake><im><im><fake>mid str<fake><im><im><fake>"

    Args:
        prompt (`List[Union[str, ImageInput]]`): The input prompt.
        image_seq_len (`int`): The length of the image sequence.
        bos_token (`str`): The beginning of sentence token.
        image_token (`str`): The image token.
        fake_image_token (`str`): The fake image token.
        do_image_splitting (`bool`, *optional*, defaults to `False`):
            Whether each image has been split into 4 sub-images + the original image,
            in which case the per-image token group is repeated 5 times.
    """
    # Fix: `do_image_splitting` was previously read here without ever being
    # defined, raising a NameError for any prompt containing an image. It is
    # now an explicit keyword parameter with a backward-compatible default.
    input_strings = [f"{bos_token}"]
    open_image_tag = False
    for elem in prompt:
        if is_image_or_image_url(elem):
            # Each image contributes a fake-image delimiter followed by
            # `image_seq_len` image tokens; 5 such groups when splitting.
            image_string = f"{fake_image_token}{image_token * image_seq_len}" * (5 if do_image_splitting else 1)
            input_strings.append(image_string)
            open_image_tag = True
        else:
            if open_image_tag:
                # Close the preceding run of image tokens before the text.
                input_strings.append(f"{fake_image_token}")
                open_image_tag = False
            input_strings.append(elem)
    if open_image_tag:
        # Prompt ended on an image: close the trailing run of image tokens.
        input_strings.append(f"{fake_image_token}")
    return "".join(input_strings)


class Idefics2Processor(ProcessorMixin):
r"""
Constructs an IDEFICS2 processor which wraps a Llama tokenizer and IDEFICS2 image processor into a single processor.
Expand Down Expand Up @@ -223,6 +180,11 @@ def __call__(
fake_image_token = self.fake_image_token.content
image_token = self.image_token.content
image_str = f"{fake_image_token}{image_token * image_seq_len}{fake_image_token}"

if self.image_processor.do_image_splitting:
# A single image token is split into 4 patches + 1 original image
image_str = image_str * 5

prompt_strings = []
for sample in text:
n_images_in_text.append(sample.count(image_token))
Expand Down
72 changes: 68 additions & 4 deletions tests/models/idefics2/test_processing_idefics2.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

class Idefics2ProcessorTest(unittest.TestCase):
def setUp(self):
self.processor = Idefics2Processor.from_pretrained("/fsx/m4/victor/idefics2-tfrm-compatible")
self.processor = Idefics2Processor.from_pretrained("amyeroberts/idefics2", image_seq_len=2)
self.image1 = Image.open(
BytesIO(
requests.get(
Expand Down Expand Up @@ -56,6 +56,10 @@ def setUp(self):
self.image_seq_len = self.processor.image_seq_len

def test_process_interleaved_images_prompts_no_image_splitting(self):
old_image_splitting = self.processor.image_processor.do_image_splitting

self.processor.image_processor.do_image_splitting = False

# Test that a single image is processed correctly
inputs = self.processor(images=self.image1)
self.assertEqual(inputs["pixel_values"].shape, (1, 1, 3, 653, 980))
Expand Down Expand Up @@ -110,6 +114,69 @@ def test_process_interleaved_images_prompts_no_image_splitting(self):
self.assertEqual(inputs['pixel_attention_mask'].shape, (2, 2, 767, 980))
# fmt: on

self.processor.image_processor.do_image_splitting = old_image_splitting

def test_process_interleaved_images_prompts_image_splitting(self):
    """Check processor outputs when image splitting is enabled (each image becomes 4 crops + the original = 5 sub-images)."""
    # Save the current flag so the shared processor can be restored at the end,
    # keeping this test from leaking state into the other tests.
    old_image_splitting = self.processor.image_processor.do_image_splitting

    self.processor.image_processor.do_image_splitting = True

    # Test that a single image is processed correctly
    inputs = self.processor(images=self.image1)
    # 1 sample, 5 sub-images per input image (4 crops + original).
    self.assertEqual(inputs["pixel_values"].shape, (1, 5, 3, 653, 980))
    self.assertEqual(inputs["pixel_attention_mask"].shape, (1, 5, 653, 980))
    # fmt: on

    # Test a single sample with image and text
    image_str = "<image>"
    text_str = "In this image, we see"
    text = image_str + text_str
    inputs = self.processor(text=text, images=self.image1)

    # fmt: off
    tokenized_sentence = self.processor.tokenizer(text_str, add_special_tokens=False)
    # With splitting, the (fake_image_token + image_seq_len image tokens) group
    # is repeated 5 times per image, then closed with a final fake_image_token.
    expected_input_ids = [[self.bos_token_id] + ([self.fake_image_token_id] + [self.image_token_id] * self.image_seq_len) * 5 + [self.fake_image_token_id] + tokenized_sentence["input_ids"]]
    self.assertEqual(inputs["input_ids"], expected_input_ids)
    self.assertEqual(inputs["attention_mask"], [[1] * len(expected_input_ids[0])])
    self.assertEqual(inputs["pixel_values"].shape, (1, 5, 3, 653, 980))
    self.assertEqual(inputs["pixel_attention_mask"].shape, (1, 5, 653, 980))
    # fmt: on

    # Test that batch is correctly processed
    image_str = "<image>"
    text_str_1 = "In this image, we see"
    text_str_2 = "bla, bla"

    text = [
        image_str + text_str_1,
        text_str_2 + image_str + image_str,
    ]
    images = [[self.image1], [self.image2, self.image3]]

    inputs = self.processor(text=text, images=images, padding=True)

    # fmt: off
    tokenized_sentence_1 = self.processor.tokenizer(text_str_1, add_special_tokens=False)
    tokenized_sentence_2 = self.processor.tokenizer(text_str_2, add_special_tokens=False)
    expected_input_ids_1 = [self.bos_token_id] + ([self.fake_image_token_id] + [self.image_token_id] * self.image_seq_len) * 5 + [self.fake_image_token_id] + tokenized_sentence_1["input_ids"]
    # The second sample has two images; consecutive image groups share the
    # closing fake_image_token, hence only one trailing fake_image_token_id.
    expected_input_ids_2 = [self.bos_token_id] + tokenized_sentence_2["input_ids"] + ([self.fake_image_token_id] + [self.image_token_id] * self.image_seq_len) * 5 + ([self.fake_image_token_id] + [self.image_token_id] * self.image_seq_len) * 5 + [self.fake_image_token_id]
    # Pad the first input to match the second input (left-padding with id 0,
    # mirrored by zeros at the start of the attention mask below).
    pad_len = len(expected_input_ids_2) - len(expected_input_ids_1)
    padded_expected_input_ids_1 = [0] * pad_len + expected_input_ids_1

    self.assertEqual(
        inputs["input_ids"], [padded_expected_input_ids_1, expected_input_ids_2]
    )
    self.assertEqual(
        inputs["attention_mask"],
        [[0] * pad_len + [1] * len(expected_input_ids_1), [1] * len(expected_input_ids_2)]
    )
    # Batch of 2; the larger sample has 2 images * 5 sub-images = 10, and the
    # first sample's pixel tensor is padded up to that count.
    self.assertEqual(inputs['pixel_values'].shape, (2, 10, 3, 767, 980))
    self.assertEqual(inputs['pixel_attention_mask'].shape, (2, 10, 767, 980))
    # fmt: on

    # Restore the processor's original image-splitting setting.
    self.processor.image_processor.do_image_splitting = old_image_splitting

def test_add_special_tokens_processor(self):
image_str = "<image>"
text_str = "In this image, we see"
Expand Down Expand Up @@ -160,7 +227,4 @@ def test_apply_chat_template(self):
"User: And who is that?<end_of_utterance>\n"
"Assistant:"
)
if self.processor.image_processor.do_image_splitting:
expected_rendered = expected_rendered.replace("<fake_token_around_image><image><image>", "<fake_token_around_image><image><image>" * 5)

self.assertEqual(rendered, expected_rendered)

0 comments on commit 107693a

Please sign in to comment.