Chat template: save and load correctly for processors #33462

Merged · 8 commits · Sep 18, 2024
9 changes: 6 additions & 3 deletions src/transformers/processing_utils.py
@@ -502,9 +502,12 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
         output_chat_template_file = os.path.join(save_directory, CHAT_TEMPLATE_NAME)
 
         processor_dict = self.to_dict()
-        chat_template = processor_dict.pop("chat_template", None)
-        if chat_template is not None:
-            chat_template_json_string = json.dumps({"chat_template": chat_template}, indent=2, sort_keys=True) + "\n"
+        # Save `chat_template` in its own file. We can't get it from `processor_dict`, as we popped it in `to_dict`
+        # to avoid serializing the chat template into the JSON config file, so we get it from `self` directly.
+        if self.chat_template is not None:
+            chat_template_json_string = (
+                json.dumps({"chat_template": self.chat_template}, indent=2, sort_keys=True) + "\n"
+            )
             with open(output_chat_template_file, "w", encoding="utf-8") as writer:
                 writer.write(chat_template_json_string)
             logger.info(f"chat template saved in {output_chat_template_file}")
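
The net effect of this hunk: `save_pretrained` now reads the template from `self.chat_template` (rather than from the dict that `to_dict` already stripped it from) and writes it to a standalone JSON file next to the processor config. A minimal round-trip sketch of the new behavior, using the same classes and dummy template as the tests below; the `chat_template.json` filename is an assumption about what `CHAT_TEMPLATE_NAME` resolves to:

import json
import os
import tempfile

from transformers import CLIPImageProcessor, LlamaTokenizerFast, LlavaProcessor

tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b")
image_processor = CLIPImageProcessor()
processor = LlavaProcessor(image_processor, tokenizer, chat_template="dummy_template")

with tempfile.TemporaryDirectory() as tmpdir:
    processor.save_pretrained(tmpdir)

    # The template lands in its own file (assumed name, per CHAT_TEMPLATE_NAME),
    # not in processor_config.json.
    with open(os.path.join(tmpdir, "chat_template.json"), encoding="utf-8") as f:
        assert json.load(f)["chat_template"] == "dummy_template"

    # And it is restored onto the processor when loading back.
    reloaded = LlavaProcessor.from_pretrained(tmpdir)
    assert reloaded.chat_template == "dummy_template"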
29 changes: 26 additions & 3 deletions tests/models/llava/test_processor_llava.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import shutil
 import tempfile
 import unittest
@@ -32,11 +33,11 @@ class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
 
     def setUp(self):
         self.tmpdirname = tempfile.mkdtemp()
-
         image_processor = CLIPImageProcessor(do_center_crop=False)
         tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b")
-
-        processor = LlavaProcessor(image_processor=image_processor, tokenizer=tokenizer)
-
+        processor_kwargs = self.prepare_processor_dict()
+        processor = LlavaProcessor(image_processor, tokenizer, **processor_kwargs)
         processor.save_pretrained(self.tmpdirname)
 
     def get_tokenizer(self, **kwargs):
@@ -48,6 +49,28 @@ def get_image_processor(self, **kwargs):
     def tearDown(self):
         shutil.rmtree(self.tmpdirname)
 
+    def prepare_processor_dict(self):
+        return {"chat_template": "dummy_template"}
+
+    @unittest.skip(
+        "Skip because the model has no processor kwargs except for the chat template, and "
+        "the chat template is saved as a separate file. Stop skipping this test when the processor "
+        "has new kwargs saved in the config file."
+    )
+    def test_processor_to_json_string(self):
+        pass
+
+    def test_chat_template_is_saved(self):
+        processor_loaded = self.processor_class.from_pretrained(self.tmpdirname)
+        processor_dict_loaded = json.loads(processor_loaded.to_json_string())
+        # chat templates aren't serialized to JSON in processors
+        self.assertFalse("chat_template" in processor_dict_loaded.keys())
+
+        # they have to be saved in a separate file and loaded back from that file,
+        # so we check that the same template is loaded
+        processor_dict = self.prepare_processor_dict()
+        self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
+
     def test_can_load_various_tokenizers(self):
         for checkpoint in ["Intel/llava-gemma-2b", "llava-hf/llava-1.5-7b-hf"]:
             processor = LlavaProcessor.from_pretrained(checkpoint)
49 changes: 47 additions & 2 deletions tests/models/llava_next/test_processor_llava_next.py
@@ -11,20 +11,65 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import tempfile
 import unittest
 
 import torch
 
+from transformers import AutoProcessor, LlamaTokenizerFast, LlavaNextProcessor
 from transformers.testing_utils import require_vision
 from transformers.utils import is_vision_available
 
+from ...test_processing_common import ProcessorTesterMixin
+
 
 if is_vision_available():
-    from transformers import AutoProcessor
+    from transformers import CLIPImageProcessor
 
 
 @require_vision
-class LlavaProcessorTest(unittest.TestCase):
+class LlavaNextProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    processor_class = LlavaNextProcessor
+
+    def setUp(self):
+        self.tmpdirname = tempfile.mkdtemp()
+
+        image_processor = CLIPImageProcessor()
+        tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b")
+        processor_kwargs = self.prepare_processor_dict()
+        processor = LlavaNextProcessor(image_processor, tokenizer, **processor_kwargs)
+        processor.save_pretrained(self.tmpdirname)
+
+    def get_tokenizer(self, **kwargs):
+        return LlavaNextProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
+
+    def get_image_processor(self, **kwargs):
+        return LlavaNextProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
+
+    def prepare_processor_dict(self):
+        return {"chat_template": "dummy_template"}
+
+    @unittest.skip(
+        "Skip because the model has no processor kwargs except for the chat template, and "
+        "the chat template is saved as a separate file. Stop skipping this test when the processor "
+        "has new kwargs saved in the config file."
+    )
+    def test_processor_to_json_string(self):
+        pass
+
+    # Copied from tests.models.llava.test_processor_llava.LlavaProcessorTest.test_chat_template_is_saved
+    def test_chat_template_is_saved(self):
+        processor_loaded = self.processor_class.from_pretrained(self.tmpdirname)
+        processor_dict_loaded = json.loads(processor_loaded.to_json_string())
+        # chat templates aren't serialized to JSON in processors
+        self.assertFalse("chat_template" in processor_dict_loaded.keys())
+
+        # they have to be saved in a separate file and loaded back from that file,
+        # so we check that the same template is loaded
+        processor_dict = self.prepare_processor_dict()
+        self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
 
     def test_chat_template(self):
         processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
         expected_prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
29 changes: 27 additions & 2 deletions tests/models/llava_onevision/test_processing_llava_onevision.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import shutil
 import tempfile
 import unittest
@@ -40,9 +41,10 @@ def setUp(self):
         image_processor = LlavaOnevisionImageProcessor()
         video_processor = LlavaOnevisionVideoProcessor()
         tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+        processor_kwargs = self.prepare_processor_dict()
 
         processor = LlavaOnevisionProcessor(
-            video_processor=video_processor, image_processor=image_processor, tokenizer=tokenizer
+            video_processor=video_processor, image_processor=image_processor, tokenizer=tokenizer, **processor_kwargs
         )
         processor.save_pretrained(self.tmpdirname)
 
@@ -52,9 +54,32 @@ def get_tokenizer(self, **kwargs):
     def get_image_processor(self, **kwargs):
         return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
 
-    def get_Video_processor(self, **kwargs):
+    def get_video_processor(self, **kwargs):
         return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor
 
+    def prepare_processor_dict(self):
+        return {"chat_template": "dummy_template"}
+
+    @unittest.skip(
+        "Skip because the model has no processor kwargs except for the chat template, and "
+        "the chat template is saved as a separate file. Stop skipping this test when the processor "
+        "has new kwargs saved in the config file."
+    )
+    def test_processor_to_json_string(self):
+        pass
+
+    # Copied from tests.models.llava.test_processor_llava.LlavaProcessorTest.test_chat_template_is_saved
+    def test_chat_template_is_saved(self):
+        processor_loaded = self.processor_class.from_pretrained(self.tmpdirname)
+        processor_dict_loaded = json.loads(processor_loaded.to_json_string())
+        # chat templates aren't serialized to JSON in processors
+        self.assertFalse("chat_template" in processor_dict_loaded.keys())
+
+        # they have to be saved in a separate file and loaded back from that file,
+        # so we check that the same template is loaded
+        processor_dict = self.prepare_processor_dict()
+        self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
+
     def tearDown(self):
         shutil.rmtree(self.tmpdirname)