diff --git a/src/routes/openai.py b/src/routes/openai.py index ba2d1b7..2fffc01 100644 --- a/src/routes/openai.py +++ b/src/routes/openai.py @@ -3,7 +3,7 @@ from fastapi import APIRouter from openai.types.chat import ChatCompletionContentPartTextParam from promplate import Message -from pydantic import field_serializer +from pydantic import field_validator from typing_extensions import TypedDict from ..utils.llm import Model @@ -47,12 +47,13 @@ class ChatInput(ChainInput): stream: bool = False messages: list[CompatibleMessage] # type: ignore - @field_serializer("messages") - def serialize_messages(self, value: CompatibleMessage): - content = value["content"] - if isinstance(content, str): - return value - value["content"] = "".join(i["text"] for i in content) + @field_validator("messages", mode="after") + def serialize_messages(cls, value: list[CompatibleMessage]): + for msg in value: + content = msg["content"] + if isinstance(content, str): + continue + msg["content"] = "".join(i["text"] for i in content) return value @property