From 5a77804f0ac89692cfd783cfffd002dd103f59f2 Mon Sep 17 00:00:00 2001 From: showpune Date: Tue, 18 Jun 2024 07:21:47 +0800 Subject: [PATCH] Use Model to generate query --- .../samples/petclinic/chat/Agent.java | 6 +-- .../samples/petclinic/chat/AgentConfig.java | 6 +-- .../chat/ModeledQuestionAnswerAdvisor.java | 42 +++++++++++++++++++ 3 files changed, 46 insertions(+), 8 deletions(-) create mode 100644 src/main/java/org/springframework/samples/petclinic/chat/ModeledQuestionAnswerAdvisor.java diff --git a/src/main/java/org/springframework/samples/petclinic/chat/Agent.java b/src/main/java/org/springframework/samples/petclinic/chat/Agent.java index 219bffd..037fc02 100644 --- a/src/main/java/org/springframework/samples/petclinic/chat/Agent.java +++ b/src/main/java/org/springframework/samples/petclinic/chat/Agent.java @@ -1,7 +1,6 @@ package org.springframework.samples.petclinic.chat; import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.vectorstore.VectorStore; @@ -27,8 +26,6 @@ public class Agent { @Autowired private ChatClient chatClient; - @Autowired - private ChatModel chatModel; @Autowired private VectorStore vectorStore; @Value("classpath:/prompts/system-message.st") @@ -37,7 +34,6 @@ public class Agent { public String chat(String userMessage, String username) { try { - String processedMessage = chatModel.call(TRANSLATE + "\n" + userMessage); Consumer advisorSpecConsumer = advisorSpec -> { advisorSpec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, username); @@ -52,7 +48,7 @@ public String chat(String userMessage, String username) { //userName as memory key .advisors(advisorSpecConsumer) .system(systemPromptTemplate.render(systemParameters)) - .user(processedMessage) + .user(userMessage) .functions("queryOwners", "addOwner", "updateOwner", "queryVets") .call() .content(); diff --git a/src/main/java/org/springframework/samples/petclinic/chat/AgentConfig.java b/src/main/java/org/springframework/samples/petclinic/chat/AgentConfig.java index 9fb5840..c90f19a 100644 --- a/src/main/java/org/springframework/samples/petclinic/chat/AgentConfig.java +++ b/src/main/java/org/springframework/samples/petclinic/chat/AgentConfig.java @@ -4,9 +4,9 @@ import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.ChatClientCustomizer; import org.springframework.ai.chat.client.advisor.PromptChatMemoryAdvisor; -import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor; import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.memory.InMemoryChatMemory; +import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.reader.TextReader; @@ -31,9 +31,9 @@ public ChatClient chatClient(ChatClient.Builder chatClientBuilder) { } @Bean - public ChatClientCustomizer chatClientCustomizer(VectorStore vectorStore) { + public ChatClientCustomizer chatClientCustomizer(VectorStore vectorStore, ChatModel model) { ChatMemory chatMemory = new InMemoryChatMemory(); - return b -> b.defaultAdvisors(new PromptChatMemoryAdvisor(chatMemory), new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults())); + return b -> b.defaultAdvisors(new PromptChatMemoryAdvisor(chatMemory), new ModeledQuestionAnswerAdvisor(vectorStore, SearchRequest.defaults(), model)); } @Bean diff --git a/src/main/java/org/springframework/samples/petclinic/chat/ModeledQuestionAnswerAdvisor.java b/src/main/java/org/springframework/samples/petclinic/chat/ModeledQuestionAnswerAdvisor.java new file mode 100644 index 0000000..257b5d4 --- /dev/null +++ b/src/main/java/org/springframework/samples/petclinic/chat/ModeledQuestionAnswerAdvisor.java @@ -0,0 +1,42 @@ +package org.springframework.samples.petclinic.chat; + +import org.springframework.ai.chat.client.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.VectorStore; + +import java.util.Map; + +public class ModeledQuestionAnswerAdvisor extends QuestionAnswerAdvisor { + private static final String TRANSLATE = "Generate 1 different versions of a provided user query. " + + "but they should all retain the original meaning. " + + "It will be used to retrieve relevant documents and it should be in English \n" + + "Without enumerations, hyphens, or any additional formatting!"; + + private ChatModel chatModel; + + public ModeledQuestionAnswerAdvisor(VectorStore vectorStore, ChatModel chatModel, String modeledText) { + super(vectorStore); + this.chatModel = chatModel; + } + + public ModeledQuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, ChatModel chatModel) { + super(vectorStore, searchRequest); + this.chatModel = chatModel; + } + + public ModeledQuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise, ChatModel chatModel) { + super(vectorStore, searchRequest, userTextAdvise); + this.chatModel = chatModel; + } + + @Override + public AdvisedRequest adviseRequest(AdvisedRequest request, Map context) { + String originalUserText = request.userText(); + String processedMessage = chatModel.call(TRANSLATE + "\n" + request.userText()); + AdvisedRequest processedRequest = AdvisedRequest.from(request).withUserText(processedMessage).build(); + request = super.adviseRequest(processedRequest, context); + return AdvisedRequest.from(request).withUserText(originalUserText).build(); + } +}