diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md
index 78965813b1213..79d032bf8b211 100644
--- a/docs/source/serving/openai_compatible_server.md
+++ b/docs/source/serving/openai_compatible_server.md
@@ -172,12 +172,20 @@ completion = client.chat.completions.create(
   ]
 )
 ```
-Most chat templates for LLMs expect the `content` to be a `string` but there are some newer models like
-`meta-llama/Llama-Guard-3-1B` that expect the content to be parsed with the new OpenAI spec. In order to choose which
-format the content needs to be parsed in by vLLM, please use the `--chat-template-text-format` argument to specify
-between `string` or `openai`. The default value is `string` and vLLM internally converts both spec formats to match
-this, unless explicitly specified.
 
+Most chat templates for LLMs expect the `content` field to be a string, but there are some newer models like 
+`meta-llama/Llama-Guard-3-1B` that expect the content to be formatted according to the OpenAI schema in the
+request. vLLM provides best-effort support to detect this automatically, which is logged as a string like
+*"Detected the chat template content format to be..."*, and internally converts incoming requests to match
+the detected format, which can be one of:
+
+- `"string"`: A string.
+  - Example: `"Hello world"`
+- `"openai"`: A list of dictionaries, similar to OpenAI schema.
+  - Example: `[{"type": "text", "text": "Hello world!"}]`
+
+If the result is not what you expect, you can set the `--chat-template-content-format` CLI argument
+to override which format to use.
 
 ## Command line arguments for the server
 
diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
index e969d33775d86..93660e6118ca8 100644
--- a/tests/entrypoints/openai/test_serving_chat.py
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -26,7 +26,6 @@ class MockModelConfig:
     tokenizer = MODEL_NAME
     trust_remote_code = False
     tokenizer_mode = "auto"
-    chat_template_text_format = "string"
     max_model_len = 100
     tokenizer_revision = None
     multimodal_config = MultiModalConfig()
@@ -49,6 +48,7 @@ async def _async_serving_chat_init():
                                            BASE_MODEL_PATHS,
                                            response_role="assistant",
                                            chat_template=CHAT_TEMPLATE,
+                                           chat_template_content_format="auto",
                                            lora_modules=None,
                                            prompt_adapters=None,
                                            request_logger=None)
@@ -70,6 +70,7 @@ def test_serving_chat_should_set_correct_max_tokens():
                                      BASE_MODEL_PATHS,
                                      response_role="assistant",
                                      chat_template=CHAT_TEMPLATE,
+                                     chat_template_content_format="auto",
                                      lora_modules=None,
                                      prompt_adapters=None,
                                      request_logger=None)
diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py
index 5fa466f8f041f..72477e048eafa 100644
--- a/tests/entrypoints/test_chat_utils.py
+++ b/tests/entrypoints/test_chat_utils.py
@@ -6,15 +6,24 @@
 
 from vllm.assets.image import ImageAsset
 from vllm.config import ModelConfig
-from vllm.entrypoints.chat_utils import (parse_chat_messages,
-                                         parse_chat_messages_futures)
+from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
+                                         parse_chat_messages,
+                                         parse_chat_messages_futures,
+                                         resolve_chat_template_content_format)
 from vllm.entrypoints.llm import apply_hf_chat_template
 from vllm.multimodal import MultiModalDataDict
 from vllm.multimodal.utils import encode_image_base64
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 
+from ..utils import VLLM_PATH
+
+EXAMPLES_DIR = VLLM_PATH / "examples"
+
 PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
+ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_3"
+QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
 MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
 
 
 @pytest.fixture(scope="function")
@@ -26,7 +35,6 @@ def phi3v_model_config():
                        trust_remote_code=True,
                        dtype="bfloat16",
                        seed=0,
-                       chat_template_text_format="string",
                        limit_mm_per_prompt={
                            "image": 2,
                        })
@@ -94,19 +102,24 @@ def test_parse_chat_messages_single_image(
     phi3v_tokenizer,
     image_url,
 ):
-    conversation, mm_data = parse_chat_messages([{
-        "role":
-        "user",
-        "content": [{
-            "type": "image_url",
-            "image_url": {
-                "url": image_url
-            }
-        }, {
-            "type": "text",
-            "text": "What's in the image?"
-        }]
-    }], phi3v_model_config, phi3v_tokenizer)
+    conversation, mm_data = parse_chat_messages(
+        [{
+            "role":
+            "user",
+            "content": [{
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            }, {
+                "type": "text",
+                "text": "What's in the image?"
+            }]
+        }],
+        phi3v_model_config,
+        phi3v_tokenizer,
+        content_format="string",
+    )
 
     assert conversation == [{
         "role": "user",
@@ -121,19 +134,24 @@ async def test_parse_chat_messages_single_image_async(
     phi3v_tokenizer,
     image_url,
 ):
-    conversation, mm_future = parse_chat_messages_futures([{
-        "role":
-        "user",
-        "content": [{
-            "type": "image_url",
-            "image_url": {
-                "url": image_url
-            }
-        }, {
-            "type": "text",
-            "text": "What's in the image?"
-        }]
-    }], phi3v_model_config, phi3v_tokenizer)
+    conversation, mm_future = parse_chat_messages_futures(
+        [{
+            "role":
+            "user",
+            "content": [{
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            }, {
+                "type": "text",
+                "text": "What's in the image?"
+            }]
+        }],
+        phi3v_model_config,
+        phi3v_tokenizer,
+        content_format="string",
+    )
 
     assert conversation == [{
         "role": "user",
@@ -147,24 +165,29 @@ def test_parse_chat_messages_multiple_images(
     phi3v_tokenizer,
     image_url,
 ):
-    conversation, mm_data = parse_chat_messages([{
-        "role":
-        "user",
-        "content": [{
-            "type": "image_url",
-            "image_url": {
-                "url": image_url
-            }
-        }, {
-            "type": "image_url",
-            "image_url": {
-                "url": image_url
-            }
-        }, {
-            "type": "text",
-            "text": "What's in these images?"
-        }]
-    }], phi3v_model_config, phi3v_tokenizer)
+    conversation, mm_data = parse_chat_messages(
+        [{
+            "role":
+            "user",
+            "content": [{
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            }, {
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            }, {
+                "type": "text",
+                "text": "What's in these images?"
+            }]
+        }],
+        phi3v_model_config,
+        phi3v_tokenizer,
+        content_format="string",
+    )
 
     assert conversation == [{
         "role":
@@ -181,24 +204,29 @@ async def test_parse_chat_messages_multiple_images_async(
     phi3v_tokenizer,
     image_url,
 ):
-    conversation, mm_future = parse_chat_messages_futures([{
-        "role":
-        "user",
-        "content": [{
-            "type": "image_url",
-            "image_url": {
-                "url": image_url
-            }
-        }, {
-            "type": "image_url",
-            "image_url": {
-                "url": image_url
-            }
-        }, {
-            "type": "text",
-            "text": "What's in these images?"
-        }]
-    }], phi3v_model_config, phi3v_tokenizer)
+    conversation, mm_future = parse_chat_messages_futures(
+        [{
+            "role":
+            "user",
+            "content": [{
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            }, {
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            }, {
+                "type": "text",
+                "text": "What's in these images?"
+            }]
+        }],
+        phi3v_model_config,
+        phi3v_tokenizer,
+        content_format="string",
+    )
 
     assert conversation == [{
         "role":
@@ -214,27 +242,31 @@ def test_parse_chat_messages_placeholder_already_in_prompt(
     phi3v_tokenizer,
     image_url,
 ):
-    conversation, mm_data = parse_chat_messages([{
-        "role":
-        "user",
-        "content": [{
-            "type": "image_url",
-            "image_url": {
-                "url": image_url
-            }
-        }, {
-            "type": "image_url",
-            "image_url": {
-                "url": image_url
-            }
-        }, {
-            "type":
-            "text",
-            "text":
-            "What's in <|image_1|> and how does it compare to <|image_2|>?"
-        }]
-    }], phi3v_model_config, phi3v_tokenizer)
-
+    conversation, mm_data = parse_chat_messages(
+        [{
+            "role":
+            "user",
+            "content": [{
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            }, {
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            }, {
+                "type":
+                "text",
+                "text":
+                "What's in <|image_1|> and how does it compare to <|image_2|>?"
+            }]
+        }],
+        phi3v_model_config,
+        phi3v_tokenizer,
+        content_format="string",
+    )
     assert conversation == [{
         "role":
         "user",
@@ -249,26 +281,35 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt(
     phi3v_tokenizer,
     image_url,
 ):
-    conversation, mm_data = parse_chat_messages([{
-        "role":
-        "user",
-        "content": [{
-            "type": "image_url",
-            "image_url": {
-                "url": image_url
-            }
-        }, {
-            "type": "image_url",
-            "image_url": {
-                "url": image_url
-            }
-        }, {
-            "type":
-            "text",
-            "text":
-            "What's in <|image_1|> and how does it compare to the other one?"
-        }]
-    }], phi3v_model_config, phi3v_tokenizer)
+    conversation, mm_data = parse_chat_messages(
+        [{
+            "role":
+            "user",
+            "content": [
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": image_url
+                    }
+                },
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": image_url
+                    }
+                },
+                {
+                    "type":
+                    "text",
+                    "text":
+                    "What's in <|image_1|> and how does it compare to the other one?"  # noqa: E501
+                }
+            ]
+        }],
+        phi3v_model_config,
+        phi3v_tokenizer,
+        content_format="string",
+    )
 
     assert conversation == [{
         "role":
@@ -285,34 +326,39 @@ def test_parse_chat_messages_multiple_images_across_messages(
     phi3v_tokenizer,
     image_url,
 ):
-    conversation, mm_data = parse_chat_messages([{
-        "role":
-        "user",
-        "content": [{
-            "type": "image_url",
-            "image_url": {
-                "url": image_url
-            }
+    conversation, mm_data = parse_chat_messages(
+        [{
+            "role":
+            "user",
+            "content": [{
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            }, {
+                "type": "text",
+                "text": "What's in this image?"
+            }]
         }, {
-            "type": "text",
-            "text": "What's in this image?"
-        }]
-    }, {
-        "role": "assistant",
-        "content": "Some stuff."
-    }, {
-        "role":
-        "user",
-        "content": [{
-            "type": "image_url",
-            "image_url": {
-                "url": image_url
-            }
+            "role": "assistant",
+            "content": "Some stuff."
         }, {
-            "type": "text",
-            "text": "What about this one?"
-        }]
-    }], phi3v_model_config, phi3v_tokenizer)
+            "role":
+            "user",
+            "content": [{
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            }, {
+                "type": "text",
+                "text": "What about this one?"
+            }]
+        }],
+        phi3v_model_config,
+        phi3v_tokenizer,
+        content_format="string",
+    )
 
     assert conversation == [
         {
@@ -335,7 +381,6 @@ def test_parse_chat_messages_context_text_format(
     phi3v_model_config,
     phi3v_tokenizer,
 ):
-    phi3v_model_config.chat_template_text_format = "openai"
     conversation, mm_data = parse_chat_messages(
         [{
             "role": "user",
@@ -349,7 +394,11 @@ def test_parse_chat_messages_context_text_format(
         }, {
             "role": "user",
             "content": "What about this one?"
-        }], phi3v_model_config, phi3v_tokenizer)
+        }],
+        phi3v_model_config,
+        phi3v_tokenizer,
+        content_format="openai",
+    )
 
     assert conversation == [
         {
@@ -389,29 +438,34 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message(
                 ValueError,
                 match="At most 2 image\\(s\\) may be provided in one request\\."
         ):
-            parse_chat_messages([{
-                "role":
-                "user",
-                "content": [{
-                    "type": "image_url",
-                    "image_url": {
-                        "url": image_url
-                    }
-                }, {
-                    "type": "image_url",
-                    "image_url": {
-                        "url": image_url
-                    }
-                }, {
-                    "type": "image_url",
-                    "image_url": {
-                        "url": image_url
-                    }
-                }, {
-                    "type": "text",
-                    "text": "What's in these images?"
-                }]
-            }], phi3v_model_config, phi3v_tokenizer)
+            parse_chat_messages(
+                [{
+                    "role":
+                    "user",
+                    "content": [{
+                        "type": "image_url",
+                        "image_url": {
+                            "url": image_url
+                        }
+                    }, {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": image_url
+                        }
+                    }, {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": image_url
+                        }
+                    }, {
+                        "type": "text",
+                        "text": "What's in these images?"
+                    }]
+                }],
+                phi3v_model_config,
+                phi3v_tokenizer,
+                content_format="string",
+            )
 
 
 def test_parse_chat_messages_rejects_too_many_images_across_messages(
@@ -427,39 +481,44 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages(
                 ValueError,
                 match="At most 2 image\\(s\\) may be provided in one request\\."
         ):
-            parse_chat_messages([{
-                "role":
-                "user",
-                "content": [{
-                    "type": "image_url",
-                    "image_url": {
-                        "url": image_url
-                    }
+            parse_chat_messages(
+                [{
+                    "role":
+                    "user",
+                    "content": [{
+                        "type": "image_url",
+                        "image_url": {
+                            "url": image_url
+                        }
+                    }, {
+                        "type": "text",
+                        "text": "What's in this image?"
+                    }]
                 }, {
-                    "type": "text",
-                    "text": "What's in this image?"
-                }]
-            }, {
-                "role": "assistant",
-                "content": "Some stuff."
-            }, {
-                "role":
-                "user",
-                "content": [{
-                    "type": "image_url",
-                    "image_url": {
-                        "url": image_url
-                    }
+                    "role": "assistant",
+                    "content": "Some stuff."
                 }, {
-                    "type": "image_url",
-                    "image_url": {
-                        "url": image_url
-                    }
-                }, {
-                    "type": "text",
-                    "text": "What about these two?"
-                }]
-            }], phi3v_model_config, phi3v_tokenizer)
+                    "role":
+                    "user",
+                    "content": [{
+                        "type": "image_url",
+                        "image_url": {
+                            "url": image_url
+                        }
+                    }, {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": image_url
+                        }
+                    }, {
+                        "type": "text",
+                        "text": "What about these two?"
+                    }]
+                }],
+                phi3v_model_config,
+                phi3v_tokenizer,
+                content_format="string",
+            )
 
 
 def test_parse_chat_messages_multiple_images_uncommon_input(
@@ -467,17 +526,22 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
     phi3v_tokenizer,
     image_url,
 ):
-    conversation, mm_data = parse_chat_messages([{
-        "role":
-        "user",
-        "content": [
-            "What's in these images?", {
-                "image_url": image_url
-            }, {
-                "image_url": image_url
-            }
-        ]
-    }], phi3v_model_config, phi3v_tokenizer)
+    conversation, mm_data = parse_chat_messages(
+        [{
+            "role":
+            "user",
+            "content": [
+                "What's in these images?", {
+                    "image_url": image_url
+                }, {
+                    "image_url": image_url
+                }
+            ]
+        }],
+        phi3v_model_config,
+        phi3v_tokenizer,
+        content_format="string",
+    )
 
     assert conversation == [{
         "role":
@@ -495,16 +559,21 @@ def test_mllama_single_image(
     image_url,
 ):
     """Ensures that a single image is parsed correctly mllama."""
-    conversation, mm_data = parse_chat_messages([{
-        "role":
-        "user",
-        "content": [{
-            'type': 'text',
-            'text': 'The content of this image is:'
-        }, {
-            "image_url": image_url
-        }]
-    }], mllama_model_config, mllama_tokenizer)
+    conversation, mm_data = parse_chat_messages(
+        [{
+            "role":
+            "user",
+            "content": [{
+                'type': 'text',
+                'text': 'The content of this image is:'
+            }, {
+                "image_url": image_url
+            }]
+        }],
+        mllama_model_config,
+        mllama_tokenizer,
+        content_format="openai",
+    )
     _assert_mm_data_is_image_input(mm_data, 1)
     assert conversation == [{
         'role':
@@ -524,26 +593,31 @@ def test_mllama_interleaved_images(
     image_url,
 ):
     """Ensures that multiple image are parsed as interleaved dicts."""
-    conversation, mm_data = parse_chat_messages([{
-        "role":
-        "user",
-        "content": [
-            {
-                'type': 'text',
-                'text': 'The content of the first image is:'
-            },
-            {
-                "image_url": image_url
-            },
-            {
-                'type': 'text',
-                'text': 'The content of the second image is:'
-            },
-            {
-                "image_url": image_url
-            },
-        ]
-    }], mllama_model_config, mllama_tokenizer)
+    conversation, mm_data = parse_chat_messages(
+        [{
+            "role":
+            "user",
+            "content": [
+                {
+                    'type': 'text',
+                    'text': 'The content of the first image is:'
+                },
+                {
+                    "image_url": image_url
+                },
+                {
+                    'type': 'text',
+                    'text': 'The content of the second image is:'
+                },
+                {
+                    "image_url": image_url
+                },
+            ]
+        }],
+        mllama_model_config,
+        mllama_tokenizer,
+        content_format="openai",
+    )
     _assert_mm_data_is_image_input(mm_data, 2)
     assert conversation == [{
         'role':
@@ -626,6 +700,7 @@ def get_conversation(is_hf: bool):
         vllm_conversation,
         model_config,
         tokenizer_group,
+        content_format="openai",
     )
 
     vllm_result = apply_hf_chat_template(
@@ -636,3 +711,89 @@ def get_conversation(is_hf: bool):
     )
 
     assert hf_result == vllm_result
+
+
+# yapf: disable
+@pytest.mark.parametrize(
+    ("model", "expected_format"),
+    [(PHI3V_MODEL_ID, "string"),
+     (QWEN2VL_MODEL_ID, "openai"),
+     (ULTRAVOX_MODEL_ID, "string"),
+     (MLLAMA_MODEL_ID, "openai"),
+     (LLAMA_GUARD_MODEL_ID, "openai")],
+)
+# yapf: enable
+def test_resolve_content_format_hf_defined(model, expected_format):
+    tokenizer_group = TokenizerGroup(
+        model,
+        enable_lora=False,
+        max_num_seqs=5,
+        max_input_length=None,
+    )
+    tokenizer = tokenizer_group.tokenizer
+
+    chat_template = tokenizer.chat_template
+    assert isinstance(chat_template, str)
+
+    print("[TEXT]")
+    print(chat_template)
+    print("[AST]")
+    print(_try_extract_ast(chat_template))
+
+    resolved_format = resolve_chat_template_content_format(
+        None,  # Test detecting the tokenizer's chat_template
+        "auto",
+        tokenizer,
+    )
+
+    assert resolved_format == expected_format
+
+
+# yapf: disable
+@pytest.mark.parametrize(
+    ("template_path", "expected_format"),
+    [("template_alpaca.jinja", "string"),
+     ("template_baichuan.jinja", "string"),
+     ("template_blip2.jinja", "string"),
+     ("template_chatglm.jinja", "string"),
+     ("template_chatglm2.jinja", "string"),
+     ("template_chatml.jinja", "string"),
+     ("template_falcon_180b.jinja", "string"),
+     ("template_falcon.jinja", "string"),
+     ("template_inkbot.jinja", "string"),
+     ("template_llava.jinja", "string"),
+     ("template_vlm2vec.jinja", "openai"),
+     ("tool_chat_template_granite_20b_fc.jinja", "string"),
+     ("tool_chat_template_hermes.jinja", "string"),
+     ("tool_chat_template_internlm2_tool.jinja", "string"),
+     ("tool_chat_template_llama3.1_json.jinja", "string"),
+     ("tool_chat_template_llama3.2_json.jinja", "string"),
+     ("tool_chat_template_mistral_parallel.jinja", "string"),
+     ("tool_chat_template_mistral.jinja", "string")],
+)
+# yapf: enable
+def test_resolve_content_format_examples(template_path, expected_format):
+    tokenizer_group = TokenizerGroup(
+        PHI3V_MODEL_ID,
+        enable_lora=False,
+        max_num_seqs=5,
+        max_input_length=None,
+    )
+    dummy_tokenizer = tokenizer_group.tokenizer
+    dummy_tokenizer.chat_template = None
+
+    chat_template = load_chat_template(EXAMPLES_DIR / template_path)
+    assert isinstance(chat_template, str)
+
+    print("[TEXT]")
+    print(chat_template)
+    print("[AST]")
+    print(_try_extract_ast(chat_template))
+
+    resolved_format = resolve_chat_template_content_format(
+        chat_template,
+        "auto",
+        dummy_tokenizer,
+    )
+
+    assert resolved_format == expected_format
diff --git a/vllm/config.py b/vllm/config.py
index 1c190da1d327e..64b2f75e092de 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -155,7 +155,6 @@ def __init__(
             limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
             use_async_output_proc: bool = True,
             config_format: ConfigFormat = ConfigFormat.AUTO,
-            chat_template_text_format: str = "string",
             hf_overrides: Optional[HfOverrides] = None,
             mm_processor_kwargs: Optional[Dict[str, Any]] = None,
             override_neuron_config: Optional[Dict[str, Any]] = None,
@@ -216,7 +215,6 @@ def __init__(
             self.model, revision)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
         self.use_async_output_proc = use_async_output_proc
-        self.chat_template_text_format = chat_template_text_format
         self.mm_processor_kwargs = mm_processor_kwargs
 
         # Set enforce_eager to False if the value is unset.
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 89cdb4e117515..2d09c48a9ec4b 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -91,7 +91,6 @@ class EngineArgs:
     task: TaskOption = "auto"
     skip_tokenizer_init: bool = False
     tokenizer_mode: str = 'auto'
-    chat_template_text_format: str = 'string'
     trust_remote_code: bool = False
     allowed_local_media_path: str = ""
     download_dir: Optional[str] = None
@@ -259,14 +258,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             'fast tokenizer if available.\n* "slow" will '
             'always use the slow tokenizer. \n* '
             '"mistral" will always use the `mistral_common` tokenizer.')
-        parser.add_argument(
-            '--chat-template-text-format',
-            type=str,
-            default=EngineArgs.chat_template_text_format,
-            choices=['string', 'openai'],
-            help='The format to render text content within a chat template. '
-            '"string" will keep the content field as a string whereas '
-            '"openai" will parse content in the current OpenAI format.')
         parser.add_argument('--trust-remote-code',
                             action='store_true',
                             help='Trust remote code from huggingface.')
@@ -895,7 +886,6 @@ def create_model_config(self) -> ModelConfig:
             # We know this is not None because we set it in __post_init__
             tokenizer=cast(str, self.tokenizer),
             tokenizer_mode=self.tokenizer_mode,
-            chat_template_text_format=self.chat_template_text_format,
             trust_remote_code=self.trust_remote_code,
             allowed_local_media_path=self.allowed_local_media_path,
             dtype=self.dtype,
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 73abcb792ca78..58bd5cec044b8 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -262,8 +262,7 @@ def __init__(
             "num_scheduler_steps=%d, chunked_prefill_enabled=%s "
             "multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
             "use_async_output_proc=%s, use_cached_outputs=%s, "
-            "chat_template_text_format=%s, mm_processor_kwargs=%s, "
-            "pooler_config=%r)",
+            "mm_processor_kwargs=%s, pooler_config=%r)",
             VLLM_VERSION,
             model_config.model,
             speculative_config,
@@ -296,7 +295,6 @@ def __init__(
             cache_config.enable_prefix_caching,
             model_config.use_async_output_proc,
             use_cached_outputs,
-            model_config.chat_template_text_format,
             model_config.mm_processor_kwargs,
             model_config.pooler_config,
         )
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 3ca460c47c3bd..abee5ac46391c 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -2,12 +2,14 @@
 import codecs
 import json
 from abc import ABC, abstractmethod
-from collections import defaultdict
+from collections import defaultdict, deque
 from functools import lru_cache, partial
 from pathlib import Path
 from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
                     Literal, Mapping, Optional, Tuple, TypeVar, Union, cast)
 
+import jinja2.nodes
+import transformers.utils.chat_template_utils as hf_chat_utils
 # yapf conflicts with isort for this block
 # yapf: disable
 from openai.types.chat import (ChatCompletionAssistantMessageParam,
@@ -153,6 +155,199 @@ class ConversationMessage(TypedDict, total=False):
     """The tool calls generated by the model, such as function calls."""
 
 
+# Passed in by user
+ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]
+
+# Used internally
+_ChatTemplateContentFormat = Literal["string", "openai"]
+
+
+def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
+    if isinstance(node, jinja2.nodes.Name):
+        return node.ctx == "load" and node.name == varname
+
+    return False
+
+
+def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
+    if isinstance(node, jinja2.nodes.Getitem):
+        return (_is_var_access(node.node, varname)
+                and isinstance(node.arg, jinja2.nodes.Const)
+                and node.arg.value == key)
+
+    if isinstance(node, jinja2.nodes.Getattr):
+        return _is_var_access(node.node, varname) and node.attr == key
+
+    return False
+
+
+def _is_var_or_elems_access(
+    node: jinja2.nodes.Node,
+    varname: str,
+    key: Optional[str] = None,
+) -> bool:
+    if isinstance(node, jinja2.nodes.Filter):
+        return (node.node is not None
+                and _is_var_or_elems_access(node.node, varname, key))
+    if isinstance(node, jinja2.nodes.Test):
+        return _is_var_or_elems_access(node.node, varname, key)
+
+    if (isinstance(node, jinja2.nodes.Getitem)
+            and isinstance(node.arg, jinja2.nodes.Slice)):
+        return _is_var_or_elems_access(node.node, varname, key)
+
+    # yapf: disable
+    return (
+        _is_attr_access(node, varname, key) if key
+        else _is_var_access(node, varname)
+    ) # yapf: enable
+
+
+def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
+    # Global variable that is implicitly defined at the root
+    yield root, varname
+
+    # Iterative BFS
+    related_varnames = deque([varname])
+    while related_varnames:
+        related_varname = related_varnames.popleft()
+
+        for assign_ast in root.find_all(jinja2.nodes.Assign):
+            lhs = assign_ast.target
+            rhs = assign_ast.node
+
+            if _is_var_or_elems_access(rhs, related_varname):
+                assert isinstance(lhs, jinja2.nodes.Name)
+                yield assign_ast, lhs.name
+
+                # Avoid infinite looping for self-assignment
+                if lhs.name != related_varname:
+                    related_varnames.append(lhs.name)
+
+
+# NOTE: The proper way to handle this is to build a CFG so that we can handle
+# the scope in which each variable is defined, but that is too complicated
+def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
+    messages_varnames = [
+        varname
+        for _, varname in _iter_nodes_assign_var_or_elems(root, "messages")
+    ]
+
+    # Search for {%- for message in messages -%} loops
+    for loop_ast in root.find_all(jinja2.nodes.For):
+        loop_iter = loop_ast.iter
+        loop_target = loop_ast.target
+
+        for varname in messages_varnames:
+            if _is_var_or_elems_access(loop_iter, varname):
+                assert isinstance(loop_target, jinja2.nodes.Name)
+                yield loop_ast, loop_target.name
+                break
+
+
+def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
+    message_varnames = [
+        varname for _, varname in _iter_nodes_assign_messages_item(root)
+    ]
+
+    # Search for {%- for content in message['content'] -%} loops
+    for loop_ast in root.find_all(jinja2.nodes.For):
+        loop_iter = loop_ast.iter
+        loop_target = loop_ast.target
+
+        for varname in message_varnames:
+            if _is_var_or_elems_access(loop_iter, varname, "content"):
+                assert isinstance(loop_target, jinja2.nodes.Name)
+                yield loop_ast, loop_target.name
+                break
+
+
+def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]:
+    try:
+        jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
+        return jinja_compiled.environment.parse(chat_template)
+    except Exception:
+        logger.exception("Error when compiling Jinja template")
+        return None
+
+
+def _detect_content_format(
+    chat_template: str,
+    *,
+    default: _ChatTemplateContentFormat,
+) -> _ChatTemplateContentFormat:
+    jinja_ast = _try_extract_ast(chat_template)
+    if jinja_ast is None:
+        return default
+
+    try:
+        next(_iter_nodes_assign_content_item(jinja_ast))
+    except StopIteration:
+        return "string"
+    except Exception:
+        logger.exception("Error when parsing AST of Jinja template")
+        return default
+    else:
+        return "openai"
+
+
+def _resolve_chat_template_content_format(
+    chat_template: Optional[str],
+    given_format: ChatTemplateContentFormatOption,
+    tokenizer: AnyTokenizer,
+) -> _ChatTemplateContentFormat:
+    if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
+        tokenizer_chat_template = tokenizer.chat_template
+    else:
+        tokenizer_chat_template = None
+
+    jinja_text: Optional[str]
+    if isinstance(tokenizer_chat_template, str) and chat_template is None:
+        jinja_text = tokenizer_chat_template
+    elif (isinstance(tokenizer_chat_template, dict)
+            and chat_template in tokenizer_chat_template):
+        jinja_text = tokenizer_chat_template[chat_template]
+    else:
+        jinja_text = load_chat_template(chat_template, is_literal=True)
+
+    detected_format = ("string" if jinja_text is None else
+                       _detect_content_format(jinja_text, default="string"))
+
+    return detected_format if given_format == "auto" else given_format
+
+
+@lru_cache
+def resolve_chat_template_content_format(
+    chat_template: Optional[str],
+    given_format: ChatTemplateContentFormatOption,
+    tokenizer: AnyTokenizer,
+) -> _ChatTemplateContentFormat:
+    detected_format = _resolve_chat_template_content_format(
+        chat_template,
+        given_format,
+        tokenizer,
+    )
+
+    logger.info(
+        "Detected the chat template content format to be '%s'. "
+        "You can set `--chat-template-content-format` to override this.",
+        detected_format,
+    )
+
+    if given_format != "auto" and given_format != detected_format:
+        logger.warning(
+            "You specified `--chat-template-content-format %s` "
+            "which is different from the detected format '%s'. "
+            "If our automatic detection is incorrect, please consider "
+            "opening a GitHub issue so that we can improve it: "
+            "https://github.com/vllm-project/vllm/issues/new/choose",
+            given_format,
+            detected_format,
+        )
+
+    return detected_format
+
+
 ModalityStr = Literal["image", "audio", "video"]
 _T = TypeVar("_T")
 
@@ -407,12 +602,23 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]):
 
 
 def load_chat_template(
-        chat_template: Optional[Union[Path, str]]) -> Optional[str]:
+    chat_template: Optional[Union[Path, str]],
+    *,
+    is_literal: bool = False,
+) -> Optional[str]:
     if chat_template is None:
         return None
+
+    if is_literal:
+        if isinstance(chat_template, Path):
+            raise TypeError("chat_template is expected to be read directly "
+                            "from its value")
+
+        return codecs.decode(chat_template, "unicode_escape")
+
     try:
         with open(chat_template) as f:
-            resolved_chat_template = f.read()
+            return f.read()
     except OSError as e:
         if isinstance(chat_template, Path):
             raise
@@ -426,10 +632,7 @@ def load_chat_template(
 
         # If opening a file fails, set chat template to be args to
         # ensure we decode so our escape are interpreted correctly
-        resolved_chat_template = codecs.decode(chat_template, "unicode_escape")
-
-    logger.info("Using supplied chat template:\n%s", resolved_chat_template)
-    return resolved_chat_template
+        return load_chat_template(chat_template, is_literal=True)
 
 
 # TODO: Let user specify how to insert multimodal tokens into prompt
@@ -464,7 +667,6 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
 _AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
 _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
 _VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
-MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'}
 
 # Define a mapping from part types to their corresponding parsing functions.
 MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = {
@@ -542,18 +744,12 @@ def _parse_chat_message_content_parts(
     role: str,
     parts: Iterable[ChatCompletionContentPartParam],
     mm_tracker: BaseMultiModalItemTracker,
-    chat_template_text_format: str,
+    *,
+    wrap_dicts: bool,
 ) -> List[ConversationMessage]:
     content: List[Union[str, Dict[str, str]]] = []
 
     mm_parser = mm_tracker.create_parser()
-    model_config = mm_tracker.model_config
-
-    wrap_dicts = (chat_template_text_format == "openai"
-                  or (model_config.task == "embedding"
-                      and model_config.is_multimodal_model)
-                  or (model_config.hf_config.model_type
-                      in MODEL_KEEP_MULTI_MODAL_CONTENT))
 
     for part in parts:
         parse_res = _parse_chat_message_content_part(
@@ -578,9 +774,11 @@ def _parse_chat_message_content_parts(
 
 
 def _parse_chat_message_content_part(
-        part: ChatCompletionContentPartParam,
-        mm_parser: BaseMultiModalContentParser,
-        wrap_dicts: bool) -> Optional[Union[str, Dict[str, str]]]:
+    part: ChatCompletionContentPartParam,
+    mm_parser: BaseMultiModalContentParser,
+    *,
+    wrap_dicts: bool,
+) -> Optional[Union[str, Dict[str, str]]]:
     """Parses a single part of a conversation. If wrap_dicts is True,
     structured dictionary pieces for texts and images will be
     wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
@@ -629,7 +827,7 @@ def _parse_chat_message_content_part(
 def _parse_chat_message_content(
     message: ChatCompletionMessageParam,
     mm_tracker: BaseMultiModalItemTracker,
-    chat_template_text_format: str,
+    content_format: _ChatTemplateContentFormat,
 ) -> List[ConversationMessage]:
     role = message["role"]
     content = message.get("content")
@@ -645,7 +843,7 @@ def _parse_chat_message_content(
         role,
         content,  # type: ignore
         mm_tracker,
-        chat_template_text_format,
+        wrap_dicts=(content_format == "openai"),
     )
 
     for result_msg in result:
@@ -684,6 +882,7 @@ def parse_chat_messages(
     messages: List[ChatCompletionMessageParam],
     model_config: ModelConfig,
     tokenizer: AnyTokenizer,
+    content_format: _ChatTemplateContentFormat,
 ) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]:
     conversation: List[ConversationMessage] = []
     mm_tracker = MultiModalItemTracker(model_config, tokenizer)
@@ -692,7 +891,7 @@ def parse_chat_messages(
         sub_messages = _parse_chat_message_content(
             msg,
             mm_tracker,
-            model_config.chat_template_text_format,
+            content_format,
         )
 
         conversation.extend(sub_messages)
@@ -706,6 +905,7 @@ def parse_chat_messages_futures(
     messages: List[ChatCompletionMessageParam],
     model_config: ModelConfig,
     tokenizer: AnyTokenizer,
+    content_format: _ChatTemplateContentFormat,
 ) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
     conversation: List[ConversationMessage] = []
     mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
@@ -714,7 +914,7 @@ def parse_chat_messages_futures(
         sub_messages = _parse_chat_message_content(
             msg,
             mm_tracker,
-            model_config.chat_template_text_format,
+            content_format,
         )
 
         conversation.extend(sub_messages)
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 4b33fc1458ee3..86b0b6893f1d9 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -13,9 +13,11 @@
                                    TaskOption)
 from vllm.engine.llm_engine import LLMEngine
 from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
+                                         ChatTemplateContentFormatOption,
                                          apply_hf_chat_template,
                                          apply_mistral_chat_template,
-                                         parse_chat_messages)
+                                         parse_chat_messages,
+                                         resolve_chat_template_content_format)
 from vllm.inputs import PromptType, TextPrompt, TokensPrompt
 from vllm.inputs.parse import parse_and_batch_prompt
 from vllm.logger import init_logger
@@ -523,6 +525,7 @@ def chat(
         use_tqdm: bool = True,
         lora_request: Optional[LoRARequest] = None,
         chat_template: Optional[str] = None,
+        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
         add_generation_prompt: bool = True,
         continue_final_message: bool = False,
         tools: Optional[List[Dict[str, Any]]] = None,
@@ -539,9 +542,11 @@ def chat(
         to the OpenAI API.
 
         Args:
-            messages: A list of conversations or a single conversation. 
-                - Each conversation is represented as a list of messages.
-                - Each message is a dictionary with 'role' and 'content' keys.
+            messages: A list of conversations or a single conversation.
+
+              - Each conversation is represented as a list of messages.
+              - Each message is a dictionary with 'role' and 'content' keys.
+
             sampling_params: The sampling parameters for text generation.
                 If None, we use the default sampling parameters. When it
                 is a single value, it is applied to every prompt. When it
@@ -551,11 +556,19 @@ def chat(
             lora_request: LoRA request to use for generation, if any.
             chat_template: The template to use for structuring the chat.
               If not provided, the model's default chat template will be used.
+            chat_template_content_format: The format to render message content.
+
+              - "string" will render the content as a string.
+                Example: ``"Who are you?"``
+              - "openai" will render the content as a list of dictionaries,
+                similar to OpenAI schema.
+                Example: ``[{"type": "text", "text": "Who are you?"}]``
+
             add_generation_prompt: If True, adds a generation template
                 to each message.
             continue_final_message: If True, continues the final message in
-                the conversation instead of starting a new one. Cannot be `True`
-                if `add_generation_prompt` is also `True`.
+                the conversation instead of starting a new one. Cannot be
+                ``True`` if ``add_generation_prompt`` is also ``True``.
             mm_processor_kwargs: Multimodal processor kwarg overrides for this
                 chat request. Only used for offline requests.
 
@@ -576,17 +589,26 @@ def chat(
                 cast(List[ChatCompletionMessageParam], messages)
             ]
 
+        tokenizer = self.get_tokenizer()
+        model_config = self.llm_engine.get_model_config()
+        resolved_content_format = resolve_chat_template_content_format(
+            chat_template,
+            chat_template_content_format,
+            tokenizer,
+        )
+
         prompts: List[Union[TokensPrompt, TextPrompt]] = []
 
         for msgs in list_of_messages:
-            tokenizer = self.get_tokenizer()
-            model_config = self.llm_engine.get_model_config()
-
             # NOTE: _parse_chat_message_content_parts() currently doesn't
             # handle mm_processor_kwargs, since there is no implementation in
             # the chat message parsing for it.
             conversation, mm_data = parse_chat_messages(
-                msgs, model_config, tokenizer)
+                msgs,
+                model_config,
+                tokenizer,
+                content_format=resolved_content_format,
+            )
 
             prompt_data: Union[str, List[int]]
             if isinstance(tokenizer, MistralTokenizer):
@@ -737,7 +759,7 @@ def encode(
                 generation, if any.
 
         Returns:
-            A list of `EmbeddingRequestOutput` objects containing the
+            A list of ``EmbeddingRequestOutput`` objects containing the
             generated embeddings in the same order as the input prompts.
 
         Note:
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index b13f6a228b4c6..b0fe061f5db4a 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -29,6 +29,7 @@
 from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.engine.multiprocessing.engine import run_mp_engine
 from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.chat_utils import load_chat_template
 from vllm.entrypoints.launcher import serve_http
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.cli_args import (make_arg_parser,
@@ -529,6 +530,9 @@ def init_app_state(
     state.engine_client = engine_client
     state.log_stats = not args.disable_log_stats
 
+    resolved_chat_template = load_chat_template(args.chat_template)
+    logger.info("Using supplied chat template:\n%s", resolved_chat_template)
+
     state.openai_serving_chat = OpenAIServingChat(
         engine_client,
         model_config,
@@ -537,7 +541,8 @@ def init_app_state(
         lora_modules=args.lora_modules,
         prompt_adapters=args.prompt_adapters,
         request_logger=request_logger,
-        chat_template=args.chat_template,
+        chat_template=resolved_chat_template,
+        chat_template_content_format=args.chat_template_content_format,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
         enable_auto_tools=args.enable_auto_tool_choice,
         tool_parser=args.tool_call_parser,
@@ -557,7 +562,8 @@ def init_app_state(
         model_config,
         base_model_paths,
         request_logger=request_logger,
-        chat_template=args.chat_template,
+        chat_template=resolved_chat_template,
+        chat_template_content_format=args.chat_template_content_format,
     ) if model_config.task == "embedding" else None
     state.openai_serving_tokenization = OpenAIServingTokenization(
         engine_client,
@@ -565,7 +571,8 @@ def init_app_state(
         base_model_paths,
         lora_modules=args.lora_modules,
         request_logger=request_logger,
-        chat_template=args.chat_template,
+        chat_template=resolved_chat_template,
+        chat_template_content_format=args.chat_template_content_format,
     )
 
 
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index eb08a89293370..24c206a1261f2 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -7,10 +7,11 @@
 import argparse
 import json
 import ssl
-from typing import List, Optional, Sequence, Union
+from typing import List, Optional, Sequence, Union, get_args
 
 from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
-from vllm.entrypoints.chat_utils import validate_chat_template
+from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
+                                         validate_chat_template)
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     PromptAdapterPath)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
@@ -132,6 +133,18 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                         help="The file path to the chat template, "
                         "or the template in single-line form "
                         "for the specified model")
+    parser.add_argument(
+        '--chat-template-content-format',
+        type=str,
+        default="auto",
+        choices=get_args(ChatTemplateContentFormatOption),
+        help='The format to render message content within a chat template.'
+        '\n\n'
+        '* "string" will render the content as a string. '
+        'Example: "Hello World"\n'
+        '* "openai" will render the content as a list of dictionaries, '
+        'similar to OpenAI schema. '
+        'Example: [{"type": "text", "text": "Hello world!"}]')
     parser.add_argument("--response-role",
                         type=nullable_str,
                         default="assistant",
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 820aefd8800d9..b7b064ae01f05 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -5,9 +5,8 @@
 from typing import Any, Dict, List, Literal, Optional, Union
 
 import torch
-from openai.types.chat import ChatCompletionContentPartParam
 from pydantic import BaseModel, ConfigDict, Field, model_validator
-from typing_extensions import Annotated, Required, TypedDict
+from typing_extensions import Annotated
 
 from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 from vllm.pooling_params import PoolingParams
@@ -35,26 +34,6 @@
 assert _LONG_INFO.max == _MOCK_LONG_INFO.max
 
 
-class CustomChatCompletionMessageParam(TypedDict, total=False):
-    """Enables custom roles in the Chat Completion API."""
-    role: Required[str]
-    """The role of the message's author."""
-
-    content: Union[str, List[ChatCompletionContentPartParam]]
-    """The contents of the message."""
-
-    name: str
-    """An optional name for the participant.
-
-    Provides the model information to differentiate between participants of the
-    same role.
-    """
-
-    tool_call_id: Optional[str]
-
-    tool_calls: Optional[List[dict]]
-
-
 class OpenAIBaseModel(BaseModel):
     # OpenAI API does not allow extra fields
     model_config = ConfigDict(extra="forbid")
@@ -1054,16 +1033,56 @@ class TokenizeCompletionRequest(OpenAIBaseModel):
     model: str
     prompt: str
 
-    add_special_tokens: bool = Field(default=True)
+    add_special_tokens: bool = Field(
+        default=True,
+        description=(
+            "If true (the default), special tokens (e.g. BOS) will be added to "
+            "the prompt."),
+    )
 
 
 class TokenizeChatRequest(OpenAIBaseModel):
     model: str
     messages: List[ChatCompletionMessageParam]
 
-    add_generation_prompt: bool = Field(default=True)
-    continue_final_message: bool = Field(default=False)
-    add_special_tokens: bool = Field(default=False)
+    add_generation_prompt: bool = Field(
+        default=True,
+        description=
+        ("If true, the generation prompt will be added to the chat template. "
+         "This is a parameter used by chat template in tokenizer config of the "
+         "model."),
+    )
+    continue_final_message: bool = Field(
+        default=False,
+        description=
+        ("If this is set, the chat will be formatted so that the final "
+         "message in the chat is open-ended, without any EOS tokens. The "
+         "model will continue this message rather than starting a new one. "
+         "This allows you to \"prefill\" part of the model's response for it. "
+         "Cannot be used at the same time as `add_generation_prompt`."),
+    )
+    add_special_tokens: bool = Field(
+        default=False,
+        description=(
+            "If true, special tokens (e.g. BOS) will be added to the prompt "
+            "on top of what is added by the chat template. "
+            "For most models, the chat template takes care of adding the "
+            "special tokens so this should be set to false (as is the "
+            "default)."),
+    )
+    chat_template: Optional[str] = Field(
+        default=None,
+        description=(
+            "A Jinja template to use for this conversion. "
+            "As of transformers v4.44, default chat template is no longer "
+            "allowed, so you must provide a chat template if the tokenizer "
+            "does not define one."),
+    )
+    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
+        default=None,
+        description=("Additional kwargs to pass to the template renderer. "
+                     "Will be accessible by the chat template."),
+    )
 
     @model_validator(mode="before")
     @classmethod
diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py
index 1b422a93263b2..00cdb3b6839f5 100644
--- a/vllm/entrypoints/openai/run_batch.py
+++ b/vllm/entrypoints/openai/run_batch.py
@@ -222,6 +222,7 @@ async def main(args):
         prompt_adapters=None,
         request_logger=request_logger,
         chat_template=None,
+        chat_template_content_format="auto",
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
     ) if model_config.task == "generate" else None
     openai_serving_embedding = OpenAIServingEmbedding(
@@ -230,6 +231,7 @@ async def main(args):
         base_model_paths,
         request_logger=request_logger,
         chat_template=None,
+        chat_template_content_format="auto",
     ) if model_config.task == "embedding" else None
 
     tracker = BatchProgressTracker()
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 77cae00ae827f..2eef909eb9319 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -10,7 +10,8 @@
 
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.chat_utils import ConversationMessage, load_chat_template
+from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
+                                         ConversationMessage)
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionLogProb, ChatCompletionLogProbs,
@@ -38,20 +39,23 @@
 
 class OpenAIServingChat(OpenAIServing):
 
-    def __init__(self,
-                 engine_client: EngineClient,
-                 model_config: ModelConfig,
-                 base_model_paths: List[BaseModelPath],
-                 response_role: str,
-                 *,
-                 lora_modules: Optional[List[LoRAModulePath]],
-                 prompt_adapters: Optional[List[PromptAdapterPath]],
-                 request_logger: Optional[RequestLogger],
-                 chat_template: Optional[str],
-                 return_tokens_as_token_ids: bool = False,
-                 enable_auto_tools: bool = False,
-                 tool_parser: Optional[str] = None,
-                 enable_prompt_tokens_details: bool = False):
+    def __init__(
+        self,
+        engine_client: EngineClient,
+        model_config: ModelConfig,
+        base_model_paths: List[BaseModelPath],
+        response_role: str,
+        *,
+        lora_modules: Optional[List[LoRAModulePath]],
+        prompt_adapters: Optional[List[PromptAdapterPath]],
+        request_logger: Optional[RequestLogger],
+        chat_template: Optional[str],
+        chat_template_content_format: ChatTemplateContentFormatOption,
+        return_tokens_as_token_ids: bool = False,
+        enable_auto_tools: bool = False,
+        tool_parser: Optional[str] = None,
+        enable_prompt_tokens_details: bool = False,
+    ) -> None:
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
                          base_model_paths=base_model_paths,
@@ -61,8 +65,8 @@ def __init__(self,
                          return_tokens_as_token_ids=return_tokens_as_token_ids)
 
         self.response_role = response_role
-        self.use_tool_use_model_template = False
-        self.chat_template = load_chat_template(chat_template)
+        self.chat_template = chat_template
+        self.chat_template_content_format: Final = chat_template_content_format
 
         # set up tool use
         self.enable_auto_tools: bool = enable_auto_tools
@@ -120,6 +124,7 @@ async def create_chat_completion(
             ) = self._maybe_get_adapters(request)
 
             tokenizer = await self.engine_client.get_tokenizer(lora_request)
+
             tool_parser = self.tool_parser
 
             # validation for OpenAI tools
@@ -157,6 +162,7 @@ async def create_chat_completion(
                 tokenizer,
                 request.messages,
                 chat_template=request.chat_template or self.chat_template,
+                chat_template_content_format=self.chat_template_content_format,
                 add_generation_prompt=request.add_generation_prompt,
                 continue_final_message=request.continue_final_message,
                 tool_dicts=tool_dicts,
diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py
index bbe7db8f13231..74ad7389784fc 100644
--- a/vllm/entrypoints/openai/serving_embedding.py
+++ b/vllm/entrypoints/openai/serving_embedding.py
@@ -1,7 +1,7 @@
 import asyncio
 import base64
 import time
-from typing import AsyncGenerator, List, Literal, Optional, Union, cast
+from typing import AsyncGenerator, Final, List, Literal, Optional, Union, cast
 
 import numpy as np
 from fastapi import Request
@@ -9,7 +9,7 @@
 
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.chat_utils import load_chat_template
+from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
                                               EmbeddingRequest,
@@ -77,7 +77,8 @@ def __init__(
         *,
         request_logger: Optional[RequestLogger],
         chat_template: Optional[str],
-    ):
+        chat_template_content_format: ChatTemplateContentFormatOption,
+    ) -> None:
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
                          base_model_paths=base_model_paths,
@@ -85,7 +86,8 @@ def __init__(
                          prompt_adapters=None,
                          request_logger=request_logger)
 
-        self.chat_template = load_chat_template(chat_template)
+        self.chat_template = chat_template
+        self.chat_template_content_format: Final = chat_template_content_format
 
     async def create_embedding(
         self,
@@ -144,6 +146,8 @@ async def create_embedding(
                     tokenizer,
                     request.messages,
                     chat_template=request.chat_template or self.chat_template,
+                    chat_template_content_format=self.
+                    chat_template_content_format,
                     add_generation_prompt=request.add_generation_prompt,
                     continue_final_message=request.continue_final_message,
                     truncate_prompt_tokens=truncate_prompt_tokens,
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index fa315fa516632..cae2877ea7e99 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -11,14 +11,16 @@
 
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
+                                         ChatTemplateContentFormatOption,
                                          ConversationMessage,
                                          apply_hf_chat_template,
                                          apply_mistral_chat_template,
-                                         parse_chat_messages_futures)
+                                         parse_chat_messages_futures,
+                                         resolve_chat_template_content_format)
 from vllm.entrypoints.logger import RequestLogger
-# yapf conflicts with isort for this block
-# yapf: disable
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               CompletionRequest,
                                               DetokenizeRequest,
@@ -426,7 +428,8 @@ async def _preprocess_chat(
         request: ChatLikeRequest,
         tokenizer: AnyTokenizer,
         messages: List[ChatCompletionMessageParam],
-        chat_template: Optional[str] = None,
+        chat_template: Optional[str],
+        chat_template_content_format: ChatTemplateContentFormatOption,
         add_generation_prompt: bool = True,
         continue_final_message: bool = False,
         tool_dicts: Optional[List[Dict[str, Any]]] = None,
@@ -437,10 +440,16 @@ async def _preprocess_chat(
         add_special_tokens: bool = False,
     ) -> Tuple[List[ConversationMessage], Sequence[RequestPrompt],
                List[TokensPrompt]]:
+        resolved_content_format = resolve_chat_template_content_format(
+            chat_template,
+            chat_template_content_format,
+            tokenizer,
+        )
         conversation, mm_data_future = parse_chat_messages_futures(
             messages,
             self.model_config,
             tokenizer,
+            content_format=resolved_content_format,
         )
 
         _chat_template_kwargs: Dict[str, Any] = dict(
diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py
index 1fd82304f7a4d..59b3b1311f881 100644
--- a/vllm/entrypoints/openai/serving_tokenization.py
+++ b/vllm/entrypoints/openai/serving_tokenization.py
@@ -1,8 +1,8 @@
-from typing import List, Optional, Union
+from typing import Final, List, Optional, Union
 
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.chat_utils import load_chat_template
+from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -33,7 +33,8 @@ def __init__(
         lora_modules: Optional[List[LoRAModulePath]],
         request_logger: Optional[RequestLogger],
         chat_template: Optional[str],
-    ):
+        chat_template_content_format: ChatTemplateContentFormatOption,
+    ) -> None:
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
                          base_model_paths=base_model_paths,
@@ -41,12 +42,8 @@ def __init__(
                          prompt_adapters=None,
                          request_logger=request_logger)
 
-        # If this is None we use the tokenizer's default chat template
-        # the list of commonly-used chat template names for HF named templates
-        hf_chat_templates: List[str] = ['default', 'tool_use']
-        self.chat_template = chat_template \
-            if chat_template in hf_chat_templates \
-            else load_chat_template(chat_template)
+        self.chat_template = chat_template
+        self.chat_template_content_format: Final = chat_template_content_format
 
     async def create_tokenize(
         self,
@@ -75,9 +72,12 @@ async def create_tokenize(
                     request,
                     tokenizer,
                     request.messages,
-                    chat_template=self.chat_template,
+                    chat_template=request.chat_template or self.chat_template,
+                    chat_template_content_format=self.
+                    chat_template_content_format,
                     add_generation_prompt=request.add_generation_prompt,
                     continue_final_message=request.continue_final_message,
+                    chat_template_kwargs=request.chat_template_kwargs,
                     add_special_tokens=request.add_special_tokens,
                 )
             else: