Skip to content

Commit

Permalink
Use Model to generate query
Browse files Browse the repository at this point in the history
  • Loading branch information
showpune committed Jun 17, 2024
1 parent 4c296c3 commit 5a77804
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package org.springframework.samples.petclinic.chat;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.vectorstore.VectorStore;
Expand All @@ -27,8 +26,6 @@ public class Agent {
@Autowired
private ChatClient chatClient;

@Autowired
private ChatModel chatModel;
@Autowired
private VectorStore vectorStore;
@Value("classpath:/prompts/system-message.st")
Expand All @@ -37,7 +34,6 @@ public class Agent {
public String chat(String userMessage, String username) {

try {
String processedMessage = chatModel.call(TRANSLATE + "\n" + userMessage);

Consumer<ChatClient.AdvisorSpec> advisorSpecConsumer = advisorSpec -> {
advisorSpec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, username);
Expand All @@ -52,7 +48,7 @@ public String chat(String userMessage, String username) {
//userName as memory key
.advisors(advisorSpecConsumer)
.system(systemPromptTemplate.render(systemParameters))
.user(processedMessage)
.user(userMessage)
.functions("queryOwners", "addOwner", "updateOwner", "queryVets")
.call()
.content();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.ChatClientCustomizer;
import org.springframework.ai.chat.client.advisor.PromptChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.InMemoryChatMemory;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.reader.TextReader;
Expand All @@ -31,9 +31,9 @@ public ChatClient chatClient(ChatClient.Builder chatClientBuilder) {
}

@Bean
public ChatClientCustomizer chatClientCustomizer(VectorStore vectorStore) {
public ChatClientCustomizer chatClientCustomizer(VectorStore vectorStore, ChatModel model) {
ChatMemory chatMemory = new InMemoryChatMemory();
return b -> b.defaultAdvisors(new PromptChatMemoryAdvisor(chatMemory), new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults()));
return b -> b.defaultAdvisors(new PromptChatMemoryAdvisor(chatMemory), new ModeledQuestionAnswerAdvisor(vectorStore, SearchRequest.defaults(), model));
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package org.springframework.samples.petclinic.chat;

import org.springframework.ai.chat.client.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;

import java.util.Map;

public class ModeledQuestionAnswerAdvisor extends QuestionAnswerAdvisor {
private static final String TRANSLATE = "Generate 1 different versions of a provided user query. " +
"but they should all retain the original meaning. " +
"It will be used to retrieve relevant documents and it should be in English \n" +
"Without enumerations, hyphens, or any additional formatting!";

private ChatModel chatModel;

public ModeledQuestionAnswerAdvisor(VectorStore vectorStore, ChatModel chatModel, String modeledText) {
super(vectorStore);
this.chatModel = chatModel;
}

public ModeledQuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, ChatModel chatModel) {
super(vectorStore, searchRequest);
this.chatModel = chatModel;
}

public ModeledQuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise, ChatModel chatModel) {
super(vectorStore, searchRequest, userTextAdvise);
this.chatModel = chatModel;
}

@Override
public AdvisedRequest adviseRequest(AdvisedRequest request, Map<String, Object> context) {
String originalUserText = request.userText();
String processedMessage = chatModel.call(TRANSLATE + "\n" + request.userText());
AdvisedRequest processedRequest = AdvisedRequest.from(request).withUserText(processedMessage).build();
request = super.adviseRequest(processedRequest, context);
return AdvisedRequest.from(request).withUserText(originalUserText).build();
}
}

0 comments on commit 5a77804

Please sign in to comment.