refactor processing
yonigozlan committed Jan 13, 2025
1 parent 1834fab commit 37bb6fc
Showing 1 changed file with 42 additions and 43 deletions.
85 changes: 42 additions & 43 deletions src/transformers/pipelines/image_text_to_text.py
@@ -57,7 +57,7 @@ def __init__(self, messages: Dict, images: Union[str, List[str], "Image.Image",
         for message in messages:
             if not ("role" in message and "content" in message):
                 raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
-        retrieve_images_in_messages(messages, images)
+        messages = retrieve_images_in_messages(messages, images)
 
         self.messages = messages
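Note: the one-line change above keeps the return value of retrieve_images_in_messages instead of discarding it. The helper mutates the message dicts in place today, so the old call happened to work, but assigning the result makes the contract explicit. A minimal sketch of the pitfall, using a hypothetical normalize helper:

def normalize(messages):
    # A helper is free to return a rebuilt list rather than mutate its argument.
    return [dict(message, normalized=True) for message in messages]

messages = [{"role": "user", "content": "hi"}]
normalize(messages)             # return value discarded: `messages` is unchanged
messages = normalize(messages)  # return value kept: `messages` is the normalized list
print(messages)                 # [{'role': 'user', 'content': 'hi', 'normalized': True}]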

@@ -73,37 +73,39 @@ def retrieve_images_in_messages(
     idx_images = 0
     for message in messages:
         for content in message["content"]:
-            if isinstance(content, dict):
-                if content.get("type") == "image":
-                    for key in ["image", "url", "path", "base64"]:
-                        if key in content:
-                            break
-                    else:
-                        if idx_images < len(images):
-                            content["image"] = images[idx_images]
-                            idx_images += 1
-                        else:
-                            raise ValueError(
-                                "The number of images in the chat messages should be the same as the number of images passed to the pipeline."
-                            )
-                # Add support for OpenAI/TGI chat format
-                elif content.get("type") == "image_url":
-                    if isinstance(content.get("image_url"), dict) and "url" in content["image_url"]:
-                        # Rewrite content to be in the Transformers chat format
-                        content["type"] = "image"
-                        content["image"] = content["image_url"]["url"]
-                        del content["image_url"]
+            if not isinstance(content, dict):
+                continue
+            content_type = content.get("type")
+            if content_type == "image":
+                if not any(key in content for key in ["image", "url", "path", "base64"]):
+                    if idx_images < len(images):
+                        # Insert the image passed as argument in the chat message
+                        content["image"] = images[idx_images]
+                        idx_images += 1
                     else:
                         raise ValueError(
-                            "Wrong format for 'image_url' content type. The content should have an 'image_url' dict with a 'url' key."
+                            "The number of images in the chat messages should be the same as the number of images passed to the pipeline."
                         )
+            # Add support for OpenAI/TGI chat format
+            elif content_type == "image_url":
+                if isinstance(content.get("image_url"), dict) and "url" in content["image_url"]:
+                    # Rewrite content to be in the Transformers chat format
+                    content["type"] = "image"
+                    content["image"] = content["image_url"]["url"]
+                    del content["image_url"]
+                else:
+                    raise ValueError(
+                        "Wrong format for 'image_url' content type. The content should have an 'image_url' dict with a 'url' key."
+                    )
 
     # The number of images passed should be consistent with the number of images in the chat without an image key
     if idx_images != len(images):
         raise ValueError(
             "The number of images in the chat messages should be the same as the number of images passed to the pipeline."
         )
 
     return messages
 
 
 @add_end_docstrings(build_pipeline_init_args(has_processor=True))
 class ImageTextToTextPipeline(Pipeline):
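To see what the refactored loop does end to end, here is a self-contained sketch that reproduces the new logic outside the pipeline; the message contents and URLs are illustrative, not taken from this commit. One placeholder image entry is filled from the positional images list, and an OpenAI/TGI-style image_url entry is rewritten to the Transformers chat format:

def retrieve_images_in_messages(messages, images):
    # Reproduces the refactored logic above for a standalone demo.
    idx_images = 0
    for message in messages:
        for content in message["content"]:
            if not isinstance(content, dict):
                continue
            content_type = content.get("type")
            if content_type == "image":
                if not any(key in content for key in ["image", "url", "path", "base64"]):
                    if idx_images < len(images):
                        # Insert the image passed as argument in the chat message
                        content["image"] = images[idx_images]
                        idx_images += 1
                    else:
                        raise ValueError(
                            "The number of images in the chat messages should be the same as the number of images passed to the pipeline."
                        )
            elif content_type == "image_url":
                if isinstance(content.get("image_url"), dict) and "url" in content["image_url"]:
                    # Rewrite OpenAI/TGI content into the Transformers chat format
                    content["type"] = "image"
                    content["image"] = content["image_url"]["url"]
                    del content["image_url"]
                else:
                    raise ValueError(
                        "Wrong format for 'image_url' content type. The content should have an 'image_url' dict with a 'url' key."
                    )
    if idx_images != len(images):
        raise ValueError(
            "The number of images in the chat messages should be the same as the number of images passed to the pipeline."
        )
    return messages

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Compare these two images."},
            {"type": "image"},  # no image/url/path/base64 key: filled from `images`
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    }
]
messages = retrieve_images_in_messages(messages, ["https://example.com/dog.png"])
print(messages[0]["content"][1])  # {'type': 'image', 'image': 'https://example.com/dog.png'}
print(messages[0]["content"][2])  # {'type': 'image', 'image': 'https://example.com/cat.png'}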
@@ -310,33 +312,30 @@ def __call__(
         return super().__call__({"images": images, "text": text}, **kwargs)
 
     def preprocess(self, inputs=None, timeout=None, continue_final_message=None, processing_kwargs=None):
+        if isinstance(inputs, Chat):
+            # If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
+            # because very few models support multiple separate, consecutive assistant messages
+            if continue_final_message is None:
+                continue_final_message = inputs.messages[-1]["role"] == "assistant"
+            model_inputs = self.processor.apply_chat_template(
+                inputs.messages,
+                add_generation_prompt=not continue_final_message,
+                continue_final_message=continue_final_message,
+                return_tensors=self.framework,
+                tokenize=True,
+                return_dict=True,
+            )
+            model_inputs["text"] = inputs
+            return model_inputs
         # In case we only have text inputs
         if isinstance(inputs, (list, tuple, str)):
             images = None
             text = inputs
             inputs_text = inputs
         else:
-            if isinstance(inputs, Chat):
-                # If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
-                # because very few models support multiple separate, consecutive assistant messages
-                if continue_final_message is None:
-                    continue_final_message = inputs.messages[-1]["role"] == "assistant"
-                model_inputs = self.processor.apply_chat_template(
-                    inputs.messages,
-                    add_generation_prompt=not continue_final_message,
-                    continue_final_message=continue_final_message,
-                    return_tensors=self.framework,
-                    tokenize=True,
-                    return_dict=True,
-                )
-                model_inputs["text"] = inputs
-                return model_inputs
-            else:
-                text = inputs["text"]
-                inputs_text = inputs["text"]
-                images = inputs["images"]
-
-            images = load_images(images)
+            images = load_images(inputs["images"])
+            text = inputs["text"]
+            inputs_text = inputs["text"]
 
         # if batched text inputs, we set padding to True unless specified otherwise
         if isinstance(text, (list, tuple)) and len(text) > 1:
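For context, this is how the chat path through preprocess is reached in practice. A hedged usage sketch follows; the checkpoint name and image URL are illustrative, not taken from this commit. Because the chat below ends with an assistant message, continue_final_message defaults to True, so generation continues that message instead of opening a new turn:

from transformers import pipeline

# Any image-text-to-text checkpoint works here; this one is illustrative.
pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
            {"type": "text", "text": "Describe this image."},
        ],
    },
    # Final assistant message: treated as a prefill by default (see above).
    {"role": "assistant", "content": [{"type": "text", "text": "There is"}]},
]

outputs = pipe(text=messages, max_new_tokens=20)
print(outputs[0]["generated_text"])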
