LLaVa-Next: Update docs with batched inference (#30857)
* update docs with batch ex

* Update docs/source/en/model_doc/llava_next.md

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* accept nested list of img

---------

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
zucchini-nlp and NielsRogge authored May 20, 2024
1 parent cd6bd0a commit 5d0bf59
Showing 3 changed files with 84 additions and 1 deletion.
41 changes: 41 additions & 0 deletions docs/source/en/model_doc/llava_next.md
@@ -68,6 +68,8 @@ The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/

## Usage example

### Single image inference

Here's how to load the model and perform inference in half-precision (`torch.float16`):

```python
@@ -94,6 +96,45 @@
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
```

### Multi-image inference

LLaVa-Next can perform inference with multiple images as input, where the images belong either to the same prompt or to different prompts (in batched inference). Here is how to do it:

```python
import requests
from PIL import Image
import torch
from transformers import AutoProcessor, LlavaNextForConditionalGeneration

# Load the model in half-precision
model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype=torch.float16, device_map="auto")
processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")

# Get three different images
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image_stop = Image.open(requests.get(url, stream=True).raw)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image_cats = Image.open(requests.get(url, stream=True).raw)

url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
image_snowman = Image.open(requests.get(url, stream=True).raw)

# Prepare a batch of two prompts, where the first one is a multi-turn conversation and the second is not
prompt = [
    "[INST] <image>\nWhat is shown in this image? [/INST] There is a red stop sign in the image. [INST] <image>\nWhat about this image? How many cats do you see? [/INST]",
    "[INST] <image>\nWhat is shown in this image? [/INST]",
]

# Feed the images in the order they appear in the text prompts:
# each "<image>" token consumes the next image from the list
inputs = processor(text=prompt, images=[image_stop, image_cats, image_snowman], padding=True, return_tensors="pt").to(model.device)

# Generate
generate_ids = model.generate(**inputs, max_new_tokens=30)
processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
```
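
The processor also accepts the images for a batch as a nested list, one sublist per prompt. Here is a minimal sketch of the equivalent call in that format, reusing the model, processor, prompts, and images defined above:

```python
# Group the images per prompt instead of passing one flat list; the processor
# flattens the nested list internally, so the result is equivalent
inputs = processor(
    text=prompt,
    images=[[image_stop, image_cats], [image_snowman]],
    padding=True,
    return_tensors="pt",
).to(model.device)

generate_ids = model.generate(**inputs, max_new_tokens=30)
processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
```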

## Model optimization

### Quantization using Bitsandbytes
26 changes: 25 additions & 1 deletion src/transformers/models/llava_next/image_processing_llava_next.py
@@ -37,6 +37,7 @@
    get_image_size,
    infer_channel_dimension_format,
    is_scaled_image,
    is_valid_image,
    make_list_of_images,
    to_numpy_array,
    valid_images,
@@ -52,6 +53,29 @@
    from PIL import Image


def make_batched_images(images) -> List[ImageInput]:
    """
    Accepts images in list or nested list format, and makes a list of images for preprocessing.

    Args:
        images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
            The input image, flat list of images, or nested list of images (one sublist per prompt).

    Returns:
        list: A list of images.
    """
    if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
        return [img for img_list in images for img in img_list]

    elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
        return images

    elif is_valid_image(images):
        return [images]

    raise ValueError(f"Could not make batched images from {images}")
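
# A minimal usage sketch for the helper above (illustrative only; assumes PIL
# images, which `is_valid_image` accepts):
#
#     from PIL import Image
#     img_a, img_b = Image.new("RGB", (16, 16)), Image.new("RGB", (16, 16))
#
#     make_batched_images(img_a)               # single image -> [img_a]
#     make_batched_images([img_a, img_b])      # flat list    -> returned unchanged
#     make_batched_images([[img_a], [img_b]])  # nested list  -> flattened to [img_a, img_b]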


def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> List[np.array]:
    """
    Divides an image into patches of a specified size.
@@ -651,7 +675,7 @@ def preprocess(
        do_pad = do_pad if do_pad is not None else self.do_pad
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

-       images = make_list_of_images(images)
+       images = make_batched_images(images)

        if not valid_images(images):
            raise ValueError(
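
Since `preprocess` now routes images through `make_batched_images`, the image processor itself can be called with a nested list. A minimal sketch, assuming the same checkpoint and reusing the three images from the docs example above (the exact patch count depends on the image sizes):

```python
from transformers import LlavaNextImageProcessor

image_processor = LlavaNextImageProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")

# One sublist per prompt; flattened internally to [image_stop, image_cats, image_snowman]
inputs = image_processor([[image_stop, image_cats], [image_snowman]], return_tensors="pt")
print(inputs.pixel_values.shape)  # (num_images, num_patches, 3, height, width)
```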
18 changes: 18 additions & 0 deletions tests/models/llava_next/test_image_processor_llava_next.py
@@ -199,3 +199,21 @@ def test_call_pytorch(self):
    @unittest.skip("LlavaNextImageProcessor doesn't treat 4 channel PIL and numpy consistently yet")  # FIXME Amy
    def test_call_numpy_4_channels(self):
        pass

    def test_nested_input(self):
        image_processing = self.image_processing_class(**self.image_processor_dict)
        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)

        # Test batched as a list of images
        encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
        expected_output_image_shape = (7, 1445, 3, 18, 18)
        self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)

        # Test batched as a nested list of images, where each sublist is one batch
        image_inputs_nested = [image_inputs[:3], image_inputs[3:]]
        encoded_images_nested = image_processing(image_inputs_nested, return_tensors="pt").pixel_values
        expected_output_image_shape = (7, 1445, 3, 18, 18)
        self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape)

        # Image processor should return same pixel values, independently of input format
        self.assertTrue((encoded_images_nested == encoded_images).all())
