Chat template: return vectorized output in processors #34275

Merged on Jan 10, 2025 (42 commits)

Changes from 12 commits

Commits
d66a928
update chat template
zucchini-nlp Oct 17, 2024
2bff795
Merge branch 'main' into chat-template-vlms
zucchini-nlp Oct 25, 2024
3c24aff
style
zucchini-nlp Oct 25, 2024
710edd1
fix tests
zucchini-nlp Oct 25, 2024
1bf58f3
Merge branch 'main' into chat-template-vlms
zucchini-nlp Oct 25, 2024
76d24ae
Merge branch 'main' into chat-template-vlms
zucchini-nlp Oct 29, 2024
eb588d1
Update src/transformers/image_utils.py
zucchini-nlp Oct 29, 2024
3de67e0
typehints + docs
zucchini-nlp Oct 29, 2024
bcf3dac
fix tests
zucchini-nlp Oct 29, 2024
87205d7
Merge branch 'main' into chat-template-vlms
zucchini-nlp Oct 29, 2024
6282694
remove unnecessary warnings
zucchini-nlp Oct 29, 2024
690c314
forgot code style :(
zucchini-nlp Oct 29, 2024
9049d64
allow users to pass backend and num frames
zucchini-nlp Oct 29, 2024
243b4c3
Update docs/source/en/chat_templating.md
zucchini-nlp Oct 30, 2024
899d20d
Update src/transformers/image_utils.py
zucchini-nlp Oct 30, 2024
47272f8
Update src/transformers/image_utils.py
zucchini-nlp Oct 30, 2024
fc8ba58
Update src/transformers/image_utils.py
zucchini-nlp Oct 30, 2024
8b0ddd7
Update src/transformers/image_utils.py
zucchini-nlp Oct 30, 2024
d2d27fb
Update src/transformers/image_utils.py
zucchini-nlp Oct 30, 2024
1adfbca
Update src/transformers/image_utils.py
zucchini-nlp Oct 30, 2024
d0209e2
Update src/transformers/processing_utils.py
zucchini-nlp Oct 30, 2024
cde21be
Merge branch 'main' into chat-template-vlms
zucchini-nlp Oct 30, 2024
34ee690
typo fix
zucchini-nlp Nov 4, 2024
3cd24ac
merge main
zucchini-nlp Nov 4, 2024
91057e4
style
zucchini-nlp Nov 4, 2024
5edb363
address comments
zucchini-nlp Nov 15, 2024
04080ea
Merge branch 'main' into chat-template-vlms
zucchini-nlp Nov 15, 2024
eb450f8
Merge branch 'main' into chat-template-vlms
zucchini-nlp Nov 18, 2024
9cc74a4
align with "pipeline" template
zucchini-nlp Nov 19, 2024
39724ef
update docs
zucchini-nlp Nov 19, 2024
72368f7
update docs
zucchini-nlp Nov 19, 2024
376e808
merge main
zucchini-nlp Jan 8, 2025
de58cb0
unpack for all kwargs?
zucchini-nlp Jan 8, 2025
71a82b5
wrong conflict resolution while rebasing
zucchini-nlp Jan 8, 2025
4e62720
tmp
zucchini-nlp Jan 8, 2025
45289f3
update docs
zucchini-nlp Jan 9, 2025
503b153
Merge branch 'main' into chat-template-vlms
zucchini-nlp Jan 9, 2025
2b54a52
Update docs/source/en/chat_templating.md
zucchini-nlp Jan 10, 2025
3c3441e
Update docs/source/en/chat_templating.md
zucchini-nlp Jan 10, 2025
4600728
Update docs/source/en/chat_templating.md
zucchini-nlp Jan 10, 2025
39875be
Update docs/source/en/chat_templating.md
zucchini-nlp Jan 10, 2025
db2ec0c
Merge branch 'main' into chat-template-vlms
zucchini-nlp Jan 10, 2025
46 changes: 44 additions & 2 deletions docs/source/en/chat_templating.md
@@ -23,7 +23,7 @@ of text (as is the case with a standard language model), the model instead conti
of one or more **messages**, each of which includes a **role**, like "user" or "assistant", as well as message text.

Much like tokenization, different models expect very different input formats for chat. This is the reason we added
**chat templates** as a feature. Chat templates are part of the tokenizer. They specify how to convert conversations,
**chat templates** as a feature. Chat templates are part of the tokenizer for text-only LLMs or processor for multimodal LLMs. They specify how to convert conversations,
represented as lists of messages, into a single tokenizable string in the format that the model expects.

Let's make this concrete with a quick example using the `mistralai/Mistral-7B-Instruct-v0.1` model:
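The code for this example is collapsed in this diff view. As a minimal sketch of what such a call looks like (the conversation below is illustrative, not taken from the diff):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")

# Illustrative conversation; Mistral's template wraps user turns in [INST] ... [/INST]
chat = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
    {"role": "user", "content": "I'd like to show off how chat templating works!"},
]

# tokenize=False returns the formatted prompt as a plain string
print(tokenizer.apply_chat_template(chat, tokenize=False))
```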
@@ -66,10 +66,12 @@ for you, allowing you to write universal code that works for any model.
## How do I use chat templates?

As you can see in the example above, chat templates are easy to use. Simply build a list of messages, with `role`
and `content` keys, and then pass it to the [`~PreTrainedTokenizer.apply_chat_template`] method. Once you do that,
and `content` keys, and then pass it to the [`~PreTrainedTokenizer.apply_chat_template`] or [`~ProcessorMixin.apply_chat_template`] method
depending on what type of model you are using. Once you do that,
you'll get output that's ready to go! When using chat templates as input for model generation, it's also a good idea
to use `add_generation_prompt=True` to add a [generation prompt](#what-are-generation-prompts).

# Usage with text-only LLMs
Here's an example of preparing input for `model.generate()`, using `Zephyr` again:

```python
@@ -116,6 +118,46 @@ How many helicopters can a human eat in one sitting?</s>
Matey, I'm afraid I must inform ye that humans cannot eat helicopters. Helicopters are not food, they are flying machines. Food is meant to be eaten, like a hearty plate o' grog, a savory bowl o' stew, or a delicious loaf o' bread. But helicopters, they be for transportin' and movin' around, not for eatin'. So, I'd say none, me hearties. None at all.
```
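The Zephyr code for this example is likewise collapsed in this diff view. A minimal sketch of what such a call looks like; the checkpoint name is an assumption, and the messages simply reuse the conversation shown in the output above:

```python
from transformers import AutoTokenizer

# Assumed checkpoint; any chat model with a chat template works the same way
checkpoint = "HuggingFaceH4/zephyr-7b-beta"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

messages = [
    {"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"},
    {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
]

# Tokenize the formatted chat and add the generation prompt for the assistant turn
tokenized_chat = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
)
print(tokenizer.decode(tokenized_chat[0]))
```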

# Usage with multimodal LLMs

For multimodal LLMs such as [LLaVA](https://huggingface.co/llava-hf), the prompts can be formatted in a similar way,
the only difference being that you also need to pass input images/videos along with the text. Therefore, each "content"
has to be a list containing either text or image/video content.

Here's an example of preparing input for the `LLaVA` model:

```python
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration

model_id = "llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
model = LlavaOnevisionForConditionalGeneration.from_pretrained(model_id) # You may want to use bfloat16 and/or move to GPU here
processor = AutoProcessor.from_pretrained(model_id)

messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are a friendly chatbot who always responds in the style of a pirate"}],
    },
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "http://images.cocodataset.org/val2017/000000039769.jpg"},
            {"type": "text", "text": "What are these?"},
        ],
    },
]

processed_chat = processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt")
print(processor.batch_decode(processed_chat["input_ids"][:, :30]))
```
This will yield a string in the input format that LLaVA expects, with a bunch of `<image>` tokens at the end.
The `<image>` tokens are placeholders, and each one will be replaced by image embeddings when running the model's
forward call. The `processed_chat` can then be passed further into `model.generate()` to generate text (a short sketch of that call follows the output below).
```text
'<|im_start|>system
You are a friendly chatbot who always responds in the style of a pirate<|im_end|><|im_start|>user <image><image><image><image><image><image><image><image>'
```
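A minimal sketch of that final generation step, reusing `model`, `processor` and `processed_chat` from the example above (the `max_new_tokens` value is illustrative):

```python
# Generate a continuation from the vectorized chat template output
output = model.generate(**processed_chat, max_new_tokens=50)
print(processor.decode(output[0], skip_special_tokens=True))
```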

Arr, 'twas easy after all!

## Is there an automated pipeline for chat?
213 changes: 213 additions & 0 deletions src/transformers/image_utils.py
@@ -15,6 +15,7 @@

import base64
import os
from contextlib import redirect_stdout
from io import BytesIO
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union

@@ -24,13 +25,17 @@

from .utils import (
    ExplicitEnum,
    is_av_available,
    is_cv2_available,
    is_decord_available,
    is_jax_tensor,
    is_numpy_array,
    is_tf_tensor,
    is_torch_available,
    is_torch_tensor,
    is_torchvision_available,
    is_vision_available,
    is_yt_dlp_available,
    logging,
    requires_backends,
    to_numpy,
@@ -55,6 +60,7 @@
PILImageResampling = PIL.Image

if is_torchvision_available():
from torchvision import io as torchvision_io
from torchvision.transforms import InterpolationMode

pil_torch_interpolation_mapping = {
@@ -66,6 +72,17 @@
PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
}

if is_decord_available():
    from decord import VideoReader, cpu

if is_av_available():
    import av

if is_cv2_available():
    import cv2

if is_yt_dlp_available():
    from yt_dlp import YoutubeDL
Comment on lines +76 to +86

@hmellor (Member) commented on Feb 26, 2025:

This block breaks lazy importing of cv2, which vLLM strictly enforces. It happens when vLLM imports `from transformers.image_utils import ImageInput`. vLLM cannot upgrade to v4.49.0 because of it (vllm-project/vllm#13905).

Would it be possible to delay this import? This would be preferable to lazily importing `ImageInput` everywhere it's used in vLLM.

cc @ArthurZucker
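A minimal sketch of one way the import could be delayed, as the comment suggests; this is an assumption, not what the PR does. The hypothetical `read_video_opencv_lazy` below reuses the `get_uniform_frame_indices` helper added in this file and only imports `cv2` when the opencv backend is actually used:

```python
import numpy as np


def read_video_opencv_lazy(video_path: str, num_frames: int = None):
    """Same behaviour as `read_video_opencv`, but with cv2 imported lazily."""
    import cv2  # deferred, so importing transformers.image_utils does not pull in cv2

    video = cv2.VideoCapture(video_path)
    total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    indices = set(get_uniform_frame_indices(total_num_frames, num_frames=num_frames).tolist())

    frames = []
    for index in range(total_num_frames):
        success, frame = video.read()
        if not success:
            break
        if index in indices:
            frames.append(frame)
    video.release()
    return np.stack(frames)
```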


if TYPE_CHECKING:
    if is_torch_available():
@@ -385,6 +402,202 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] =
return image


def get_uniform_frame_indices(total_num_frames: int, num_frames: int = None):
    """
    Creates a numpy array for uniform sampling of `num_frames` frames from `total_num_frames`
    when loading a video.

    Args:
        total_num_frames (`int`):
            Total number of frames that a video has.
        num_frames (`int`, *optional*):
            Number of frames to sample uniformly. If not specified, all frames are sampled.

    Returns:
        np.ndarray: np array of frame indices that will be sampled.
    """
    if num_frames is not None:
        indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(int)
    else:
        indices = np.arange(0, total_num_frames).astype(int)
    return indices


def read_video_opencv(video_path: str, num_frames: int = None):
    """
    Decode the video with open-cv decoder.

    Args:
        video_path (`str`):
            Path to the video file.
        num_frames (`int`, *optional*):
            Number of frames to sample uniformly. If not specified, all frames are sampled.

    Returns:
        np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
    """
    video = cv2.VideoCapture(video_path)
    total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    indices = get_uniform_frame_indices(total_num_frames, num_frames=num_frames)

    index = 0
    frames = []
    while video.isOpened():
        # check the read result before touching the frame, so a failed read cannot crash on `frame.shape`
        success, frame = video.read()
        if not success:
            break
        if index in indices:
            height, width, channel = frame.shape
            frames.append(frame[0:height, 0:width, 0:channel])
        index += 1
        if index >= total_num_frames:
            break

    video.release()
    return np.stack(frames)


def read_video_decord(video_path: str, num_frames: int = None):
    """
    Decode the video with Decord decoder.

    Args:
        video_path (`str`):
            Path to the video file.
        num_frames (`int`, *optional*):
            Number of frames to sample uniformly. If not specified, all frames are sampled.

    Returns:
        np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
    """
    vr = VideoReader(uri=video_path, ctx=cpu(0))  # decord has problems with gpu
    indices = get_uniform_frame_indices(total_num_frames=len(vr), num_frames=num_frames)
    frames = vr.get_batch(indices).asnumpy()
    return frames


def read_video_pyav(video_path: str, num_frames: int = None):
    """
    Decode the video with PyAV decoder.

    Args:
        video_path (`str`):
            Path to the video file.
        num_frames (`int`, *optional*):
            Number of frames to sample uniformly. If not specified, all frames are sampled.

    Returns:
        np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
    """
    container = av.open(video_path)

    # sample uniformly "num_frames" frames from the video
    total_num_frames = container.streams.video[0].frames
    indices = get_uniform_frame_indices(total_num_frames, num_frames=num_frames)

    frames = []
    container.seek(0)
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= 0 and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])


def read_video_torchvision(video_path: str, num_frames: int = None):
    """
    Decode the video with torchvision decoder.

    Args:
        video_path (`str`):
            Path to the video file.
        num_frames (`int`, *optional*):
            Number of frames to sample uniformly. If not specified, all frames are sampled.

    Returns:
        torch.Tensor: tensor of decoded frames of shape (num_frames, 3, height, width).
    """
    video, _, info = torchvision_io.read_video(
        video_path,
        start_pts=0.0,
        end_pts=None,
        pts_unit="sec",
        output_format="TCHW",
    )

    if num_frames is not None:
        idx = torch.linspace(0, video.size(0) - 1, num_frames, dtype=torch.int64)
        return video[idx]

    return video


VIDEO_DECODERS = {
    "decord": read_video_decord,
    "opencv": read_video_opencv,
    "pyav": read_video_pyav,
    "torchvision": read_video_torchvision,
}


def load_video(video: Union[str, "VideoInput"], num_frames: int = None, backend: str = "opencv") -> np.array:
    """
    Loads `video` to a numpy array.

    Args:
        video (`str` or `VideoInput`):
            The video to convert to the numpy array format. Can be a link to a video or a local path.
        num_frames (`int`, *optional*):
            Number of frames to sample uniformly. If not passed, the whole video is loaded.
        backend (`str`, *optional*, defaults to `"opencv"`):
            The decoding backend to use. Can be one of `"opencv"`, `"decord"`, `"pyav"` or `"torchvision"`.

    Returns:
        `np.array`: A numpy array of decoded frames of shape (num_frames, height, width, 3)
        (the `torchvision` backend returns a channels-first `torch.Tensor` instead).
    """
    if video.startswith("https://www.youtube.com") or video.startswith("http://www.youtube.com"):
        if not is_yt_dlp_available():
            raise ImportError("To load a video from a YouTube URL you have to install `yt_dlp` first.")
        buffer = BytesIO()
        with redirect_stdout(buffer), YoutubeDL() as f:
            f.download([video])
        bytes_obj = buffer.getvalue()
        file_obj = BytesIO(bytes_obj)
    elif video.startswith("http://") or video.startswith("https://"):
        file_obj = BytesIO(requests.get(video).content)
Comment on lines +560 to +569 (Member):

Some additional kwargs might be required here, e.g. timeout, but probably fine for now.

    elif os.path.isfile(video):
        file_obj = video
    elif is_valid_image(video) or (isinstance(video, (list, tuple)) and is_valid_image(video[0])):
        file_obj = None
    else:
        raise TypeError("Incorrect format used for video. Should be a URL linking to a video or a local path.")

    # can also load with decord, but not cv2/torchvision
    # both will fail in case of url links
    video_is_url = video.startswith("http://") or video.startswith("https://")
    if video_is_url and backend in ["opencv", "torchvision"]:
        raise ValueError(
            "If you are trying to load a video from URL, you can decode the video only with `pyav` or `decord` as backend"
        )

    if file_obj is None:
        return video

    if (
        (not is_decord_available() and backend == "decord")
        or (not is_av_available() and backend == "pyav")
        or (not is_cv2_available() and backend == "opencv")
        or (not is_torchvision_available() and backend == "torchvision")
    ):
        raise ImportError(
            f"You chose backend={backend} for loading the video but the required library is not found in your environment. "
            f"Make sure to install {backend} before loading the video."
        )

    video_decoder = VIDEO_DECODERS[backend]
    video = video_decoder(file_obj, num_frames=num_frames)
    return video


def validate_preprocess_arguments(
    do_rescale: Optional[bool] = None,
    rescale_factor: Optional[float] = None,
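As a usage note, a minimal sketch of calling the new `load_video` helper added by this PR; the file path, frame count, and backend choice are illustrative:

```python
from transformers.image_utils import load_video

# Uniformly sample 8 frames; the opencv/decord/pyav backends return an
# (num_frames, height, width, 3) numpy array, while torchvision returns a TCHW tensor.
frames = load_video("my_video.mp4", num_frames=8, backend="opencv")
print(frames.shape)
```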