Skip to content

Commit

Permalink
[Bug][Assistant API] - chat/complete endpoint is not persisting the m…
Browse files Browse the repository at this point in the history
…odel response to the chosen conversation ID (elastic#11783) (elastic#212122)

## Summary

BUG: elastic/security-team#11783

This PR fixes the behaviour of the
`/api/security_ai_assistant/chat/complete` route where the `persist`
flag:
1. when set to `true` does not append the assistant reply to existing
conversation
2. when set to `false` appends user message to existing conversation

### Expected behaviour


[Details](elastic/security-team#11783 (comment)).

1. `conversationId == undefined && persist == false`: no new
conversations and nothing persisted
2. `conversationId == undefined && persist == true`: new conversations
is created and both user message and assistant reply appended to the new
conversation
3. `conversationId == 'existing-id' && persist == false`: nothing
appended to the existing conversation
4. `conversationId == 'existing-id' && persist == true`: both user
message and assistant reply appended to the existing conversation

### Testing

* Use this `curl` command (with replace `connectorId` and
`conversationId`) to test the endpoint.

```
curl --location 'http://localhost:5601/api/security_ai_assistant/chat/complete' \
--header 'kbn-xsrf: true' \
--header 'Content-Type: application/json' \
--data '{
  "connectorId": "{{my-gpt4o-ai}}",
  "conversationId": "{{existing-conversation-id | undefined}}",
  "isStream": false,
  "messages": [
    {
      "content": "Follow up",
      "role": "user"
    }
  ],
  "persist": true
}'
```

* To retrieve the conversation ID:
(/api/security_ai_assistant/current_user/conversations/_find)
* `conversationId` can be either existing conversation id or `undefined`
  • Loading branch information
e40pud authored Feb 26, 2025
1 parent 0121f4b commit a2b2e81
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ jest.mock('../helpers', () => {
};
});
const mockAppendAssistantMessageToConversation = appendAssistantMessageToConversation as jest.Mock;
const mockCreateConversationWithUserInput = createConversationWithUserInput as jest.Mock;

const mockLangChainExecute = langChainExecute as jest.Mock;
const mockStream = jest.fn().mockImplementation(() => new PassThrough());
Expand Down Expand Up @@ -150,7 +151,7 @@ describe('chatCompleteRoute', () => {
jest.clearAllMocks();
mockAppendAssistantMessageToConversation.mockResolvedValue(true);
license.hasAtLeast.mockReturnValue(true);
(createConversationWithUserInput as jest.Mock).mockResolvedValue({ id: 'something' });
mockCreateConversationWithUserInput.mockResolvedValue({ id: 'something' });
mockLangChainExecute.mockImplementation(
async ({
connectorId,
Expand All @@ -166,12 +167,14 @@ describe('chatCompleteRoute', () => {
) => Promise<void>;
}) => {
if (!isStream && connectorId === 'mock-connector-id') {
onLlmResponse('Non-streamed test reply.', {}, false).catch(() => {});
return {
connector_id: 'mock-connector-id',
data: mockActionResponse,
status: 'ok',
};
} else if (isStream && connectorId === 'mock-connector-id') {
onLlmResponse('Streamed test reply.', {}, false).catch(() => {});
return mockStream;
} else {
onLlmResponse('simulated error', {}, true).catch(() => {});
Expand Down Expand Up @@ -399,4 +402,141 @@ describe('chatCompleteRoute', () => {
mockGetElser
);
});

it('should add assistant reply to existing conversation when `persist=true`', async () => {
const mockRouter = {
versioned: {
post: jest.fn().mockImplementation(() => {
return {
addVersion: jest.fn().mockImplementation(async (_, handler) => {
await handler(
mockContext,
{
...mockRequest,
body: {
...mockRequest.body,
conversationId: existingConversation.id,
},
},
mockResponse
);
expect(mockAppendAssistantMessageToConversation).toHaveBeenCalledWith(
expect.objectContaining({
messageContent: 'Non-streamed test reply.',
isError: false,
})
);
expect(mockCreateConversationWithUserInput).toHaveBeenCalledTimes(0);
}),
};
}),
},
};

chatCompleteRoute(
mockRouter as unknown as IRouter<ElasticAssistantRequestHandlerContext>,
mockGetElser
);
});

it('should not add assistant reply to existing conversation when `persist=false`', async () => {
const mockRouter = {
versioned: {
post: jest.fn().mockImplementation(() => {
return {
addVersion: jest.fn().mockImplementation(async (_, handler) => {
await handler(
mockContext,
{
...mockRequest,
body: {
...mockRequest.body,
conversationId: existingConversation.id,
persist: false,
},
},
mockResponse
);
expect(mockAppendAssistantMessageToConversation).toHaveBeenCalledTimes(0);
expect(mockCreateConversationWithUserInput).toHaveBeenCalledTimes(0);
}),
};
}),
},
};

chatCompleteRoute(
mockRouter as unknown as IRouter<ElasticAssistantRequestHandlerContext>,
mockGetElser
);
});

it('should add assistant reply to new conversation when `persist=true`', async () => {
const mockRouter = {
versioned: {
post: jest.fn().mockImplementation(() => {
return {
addVersion: jest.fn().mockImplementation(async (_, handler) => {
await handler(
mockContext,
{
...mockRequest,
body: {
...mockRequest.body,
conversationId: undefined,
persist: true,
},
},
mockResponse
);
expect(mockAppendAssistantMessageToConversation).toHaveBeenCalledWith(
expect.objectContaining({
messageContent: 'Non-streamed test reply.',
isError: false,
})
);
expect(mockCreateConversationWithUserInput).toHaveBeenCalledTimes(1);
}),
};
}),
},
};

chatCompleteRoute(
mockRouter as unknown as IRouter<ElasticAssistantRequestHandlerContext>,
mockGetElser
);
});

it('should not create a new conversation when `persist=false`', async () => {
const mockRouter = {
versioned: {
post: jest.fn().mockImplementation(() => {
return {
addVersion: jest.fn().mockImplementation(async (_, handler) => {
await handler(
mockContext,
{
...mockRequest,
body: {
...mockRequest.body,
conversationId: undefined,
persist: false,
},
},
mockResponse
);
expect(mockAppendAssistantMessageToConversation).toHaveBeenCalledTimes(0);
expect(mockCreateConversationWithUserInput).toHaveBeenCalledTimes(0);
}),
};
}),
},
};

chatCompleteRoute(
mockRouter as unknown as IRouter<ElasticAssistantRequestHandlerContext>,
mockGetElser
);
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ export const chatCompleteRoute = (
await ctx.elasticAssistant.getAIAssistantAnonymizationFieldsDataClient();

let messages;
const conversationId = request.body.conversationId;
const existingConversationId = request.body.conversationId;
const connectorId = request.body.connectorId;

let latestReplacements: Replacements = {};
Expand Down Expand Up @@ -159,11 +159,10 @@ export const chatCompleteRoute = (
});

let newConversation: ConversationResponse | undefined | null;
if (conversationsDataClient && !conversationId && request.body.persist) {
if (conversationsDataClient && !existingConversationId && request.body.persist) {
newConversation = await createConversationWithUserInput({
actionTypeId,
connectorId,
conversationId,
conversationsDataClient,
promptId: request.body.promptId,
replacements: latestReplacements,
Expand All @@ -178,18 +177,23 @@ export const chatCompleteRoute = (
}));
}

// Do not persist conversation messages if `persist = false`
const conversationId = request.body.persist
? existingConversationId ?? newConversation?.id
: undefined;

const contentReferencesStore = newContentReferencesStore();

const onLlmResponse = async (
content: string,
traceData: Message['traceData'] = {},
isError = false
): Promise<void> => {
if (newConversation?.id && conversationsDataClient) {
if (conversationId && conversationsDataClient) {
const contentReferences = pruneContentReferences(content, contentReferencesStore);

await appendAssistantMessageToConversation({
conversationId: newConversation?.id,
conversationId,
conversationsDataClient,
messageContent: content,
replacements: latestReplacements,
Expand All @@ -207,7 +211,7 @@ export const chatCompleteRoute = (
actionTypeId,
connectorId,
isOssModel,
conversationId: conversationId ?? newConversation?.id,
conversationId,
context: ctx,
getElser,
logger,
Expand Down

0 comments on commit a2b2e81

Please sign in to comment.