Skip to content

Commit

Permalink
[Security AI] Add tool and chat title prompts to security-ai-prompts (
Browse files Browse the repository at this point in the history
  • Loading branch information
stephmilovic authored Feb 3, 2025
1 parent 9891bbd commit 092c2cd
Show file tree
Hide file tree
Showing 18 changed files with 345 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import { invokeGraph, streamGraph } from './helpers';
import { loggerMock } from '@kbn/logging-mocks';
import { AgentExecutorParams, AssistantDataClients } from '../../executors/types';
import { elasticsearchClientMock } from '@kbn/core-elasticsearch-client-server-mocks';
import { getPrompt, resolveProviderAndModel } from '@kbn/security-ai-prompts';
import { getFindAnonymizationFieldsResultWithSingleHit } from '../../../../__mocks__/response';
import {
createOpenAIToolsAgent,
Expand All @@ -20,7 +21,10 @@ import {
} from 'langchain/agents';
import { newContentReferencesStoreMock } from '@kbn/elastic-assistant-common/impl/content_references/content_references_store/__mocks__/content_references_store.mock';
import { savedObjectsClientMock } from '@kbn/core-saved-objects-api-server-mocks';
import { resolveProviderAndModel } from '@kbn/security-ai-prompts';
import { AssistantTool, AssistantToolParams } from '../../../..';
import { promptGroupId as toolsGroupId } from '../../../prompt/tool_prompts';
import { promptDictionary } from '../../../prompt';
import { promptGroupId } from '../../../prompt/local_prompt_object';
jest.mock('./graph');
jest.mock('./helpers');
jest.mock('langchain/agents');
Expand All @@ -29,6 +33,7 @@ jest.mock('@kbn/langchain/server/tracers/telemetry');
jest.mock('@kbn/security-ai-prompts');
const getDefaultAssistantGraphMock = getDefaultAssistantGraph as jest.Mock;
const resolveProviderAndModelMock = resolveProviderAndModel as jest.Mock;
const getPromptMock = getPrompt as jest.Mock;
describe('callAssistantGraph', () => {
const mockDataClients = {
anonymizationFieldsDataClient: {
Expand Down Expand Up @@ -98,6 +103,7 @@ describe('callAssistantGraph', () => {
(mockDataClients?.anonymizationFieldsDataClient?.findDocuments as jest.Mock).mockResolvedValue(
getFindAnonymizationFieldsResultWithSingleHit()
);
getPromptMock.mockResolvedValue('prompt');
});

it('calls invokeGraph with correct parameters for non-streaming', async () => {
Expand Down Expand Up @@ -173,6 +179,58 @@ describe('callAssistantGraph', () => {
});
});

// Verifies prompt resolution happens once per registered tool plus once for
// the default system prompt, and that the resolved prompt value is forwarded
// to each tool as its description.
it('calls getPrompt for each tool and the default system prompt', async () => {
  const getTool = jest.fn();
  const baseTool: AssistantTool = {
    id: 'id',
    name: 'name',
    description: 'description',
    sourceRegister: 'sourceRegister',
    isSupported: (toolParams: AssistantToolParams) => true,
    getTool,
  };
  const toolNames = ['test-tool', 'test-tool2'];
  const params = {
    ...defaultParams,
    assistantTools: toolNames.map((name) => ({ ...baseTool, name })),
  };

  await callAssistantGraph(params);

  // One lookup per tool, plus one for the default system prompt.
  expect(getPromptMock).toHaveBeenCalledTimes(3);
  toolNames.forEach((promptId) => {
    expect(getPromptMock).toHaveBeenCalledWith(
      expect.objectContaining({
        model: 'test-model',
        provider: 'openai',
        promptId,
        promptGroupId: toolsGroupId,
      })
    );
  });
  expect(getPromptMock).toHaveBeenCalledWith(
    expect.objectContaining({
      model: 'test-model',
      provider: 'openai',
      promptId: promptDictionary.systemPrompt,
      promptGroupId: promptGroupId.aiAssistant,
    })
  );

  // getPromptMock resolves to 'prompt' (see beforeEach); each tool should
  // receive that value as its description.
  expect(getTool).toHaveBeenCalledWith(
    expect.objectContaining({
      description: 'prompt',
    })
  );
});

describe('agentRunnable', () => {
it('creates OpenAIToolsAgent for openai llmType', async () => {
const params = { ...defaultParams, llmType: 'openai' };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ import {
import { APMTracer } from '@kbn/langchain/server/tracers/apm';
import { TelemetryTracer } from '@kbn/langchain/server/tracers/telemetry';
import { pruneContentReferences, MessageMetadata } from '@kbn/elastic-assistant-common';
import { resolveProviderAndModel } from '@kbn/security-ai-prompts';
import { getPrompt, resolveProviderAndModel } from '@kbn/security-ai-prompts';
import { localToolPrompts, promptGroupId as toolsGroupId } from '../../../prompt/tool_prompts';
import { promptGroupId } from '../../../prompt/local_prompt_object';
import { getModelOrOss } from '../../../prompt/helpers';
import { getPrompt, promptDictionary } from '../../../prompt';
import { getPrompt as localGetPrompt, promptDictionary } from '../../../prompt';
import { getLlmClass } from '../../../../routes/utils';
import { EsAnonymizationFieldsSchema } from '../../../../ai_assistant_data_clients/anonymization_fields/types';
import { AssistantToolParams } from '../../../../types';
Expand Down Expand Up @@ -124,9 +125,33 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
telemetry,
};

const tools: StructuredTool[] = assistantTools.flatMap(
(tool) => tool.getTool({ ...assistantToolParams, llm: createLlmInstance(), isOssModel }) ?? []
);
// Resolve each tool's LLM-facing description via the prompt service before
// instantiating it. Lookup is best-effort: on failure we log and leave
// `description` undefined so the tool falls back to its own default.
const tools: StructuredTool[] = (
  await Promise.all(
    assistantTools.map(async (tool) => {
      let description: string | undefined;
      try {
        // Prompt is keyed by tool name within the tools prompt group and may
        // vary by provider/model.
        description = await getPrompt({
          actionsClient,
          connectorId,
          localPrompts: localToolPrompts,
          model: getModelOrOss(llmType, isOssModel, request.body.model),
          promptId: tool.name,
          promptGroupId: toolsGroupId,
          provider: llmType,
          savedObjectsClient,
        });
      } catch (e) {
        // NOTE(review): the caught error is discarded — consider including
        // `e` in the log for easier debugging.
        logger.error(`Failed to get prompt for tool: ${tool.name}`);
      }
      // getTool may return a nullish value (e.g. unsupported tool); those
      // entries are removed by the filter below.
      return tool.getTool({
        ...assistantToolParams,
        llm: createLlmInstance(),
        isOssModel,
        description,
      });
    })
  )
).filter((e) => e != null) as StructuredTool[];

// If KB enabled, fetch for any KB IndexEntries and generate a tool for each
if (isEnabledKnowledgeBase) {
Expand All @@ -139,7 +164,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
}
}

const defaultSystemPrompt = await getPrompt({
const defaultSystemPrompt = await localGetPrompt({
actionsClient,
connectorId,
model: getModelOrOss(llmType, isOssModel, request.body.model),
Expand Down Expand Up @@ -176,7 +201,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
const telemetryTracer = telemetryParams
? new TelemetryTracer(
{
elasticTools: assistantTools.map(({ name }) => name),
elasticTools: tools.map(({ name }) => name),
totalTools: tools.length,
telemetry,
telemetryParams,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,53 +8,32 @@ import { StringOutputParser } from '@langchain/core/output_parsers';

import { ChatPromptTemplate } from '@langchain/core/prompts';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { getPrompt, promptDictionary } from '../../../../prompt';
import { AgentState, NodeParamsBase } from '../types';
import { NodeType } from '../constants';
import { promptGroupId } from '../../../../prompt/local_prompt_object';

export const GENERATE_CHAT_TITLE_PROMPT = (responseLanguage: string, llmType?: string) =>
llmType === 'bedrock'
? ChatPromptTemplate.fromMessages([
[
'system',
`You are a helpful assistant for Elastic Security. Assume the following user message is the start of a conversation between you and a user; give this conversation a title based on the content below. DO NOT UNDER ANY CIRCUMSTANCES wrap this title in single or double quotes. This title is shown in a list of conversations to the user, so title it for the user, not for you. Please create the title in ${responseLanguage}. Respond with the title only with no other text explaining your response. As an example, for the given MESSAGE, this is the TITLE:
MESSAGE: I am having trouble with the Elastic Security app.
TITLE: Troubleshooting Elastic Security app issues
`,
],
['human', '{input}'],
])
: llmType === 'gemini'
? ChatPromptTemplate.fromMessages([
[
'system',
`You are a title generator for a helpful assistant for Elastic Security. Assume the following human message is the start of a conversation between you and a human. Generate a relevant conversation title for the human's message in plain text. Make sure the title is formatted for the user, without using quotes or markdown. The title should clearly reflect the content of the message and be appropriate for a list of conversations. Please create the title in ${responseLanguage}. Respond only with the title. As an example, for the given MESSAGE, this is the TITLE:
MESSAGE: I am having trouble with the Elastic Security app.
TITLE: Troubleshooting Elastic Security app issues
`,
],
['human', '{input}'],
])
: ChatPromptTemplate.fromMessages([
[
'system',
`You are a helpful assistant for Elastic Security. Assume the following user message is the start of a conversation between you and a user; give this conversation a title based on the content below. DO NOT UNDER ANY CIRCUMSTANCES wrap this title in single or double quotes. This title is shown in a list of conversations to the user, so title it for the user, not for you. Please create the title in ${responseLanguage}. As an example, for the given MESSAGE, this is the TITLE:
MESSAGE: I am having trouble with the Elastic Security app.
TITLE: Troubleshooting Elastic Security app issues
`,
],
['human', '{input}'],
]);
/**
 * Builds the chat-title generation prompt template: the resolved system
 * prompt plus an instruction to produce the title in the requested language,
 * followed by the user's `{input}` message.
 */
export const GENERATE_CHAT_TITLE_PROMPT = ({
  prompt,
  responseLanguage,
}: {
  prompt: string;
  responseLanguage: string;
}) => {
  const systemMessage = `${prompt}\nPlease create the title in ${responseLanguage}.`;
  return ChatPromptTemplate.fromMessages([
    ['system', systemMessage],
    ['human', '{input}'],
  ]);
};

/** Parameters for the generateChatTitle graph node. */
export interface GenerateChatTitleParams extends NodeParamsBase {
  // Current agent graph state; supplies connectorId, llmType, and responseLanguage.
  state: AgentState;
  // Chat model used to generate the conversation title.
  model: BaseChatModel;
}

export async function generateChatTitle({
actionsClient,
logger,
savedObjectsClient,
state,
model,
}: GenerateChatTitleParams): Promise<Partial<AgentState>> {
Expand All @@ -64,7 +43,15 @@ export async function generateChatTitle({
);

const outputParser = new StringOutputParser();
const graph = GENERATE_CHAT_TITLE_PROMPT(state.responseLanguage, state.llmType)
const prompt = await getPrompt({
actionsClient,
connectorId: state.connectorId,
promptId: promptDictionary.chatTitle,
promptGroupId: promptGroupId.aiAssistant,
provider: state.llmType,
savedObjectsClient,
});
const graph = GENERATE_CHAT_TITLE_PROMPT({ prompt, responseLanguage: state.responseLanguage })
.pipe(model)
.pipe(outputParser);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ import {
GEMINI_SYSTEM_PROMPT,
GEMINI_USER_PROMPT,
STRUCTURED_SYSTEM_PROMPT,
BEDROCK_CHAT_TITLE,
GEMINI_CHAT_TITLE,
DEFAULT_CHAT_TITLE,
} from './prompts';

export const promptGroupId = {
Expand All @@ -31,6 +34,7 @@ export const promptGroupId = {
export const promptDictionary = {
systemPrompt: `systemPrompt`,
userPrompt: `userPrompt`,
chatTitle: `chatTitle`,
attackDiscoveryDefault: `default`,
attackDiscoveryRefine: `refine`,
attackDiscoveryContinue: `continue`,
Expand Down Expand Up @@ -154,4 +158,27 @@ export const localPrompts: Prompt[] = [
default: ATTACK_DISCOVERY_GENERATION_INSIGHTS,
},
},
{
promptId: promptDictionary.chatTitle,
promptGroupId: promptGroupId.aiAssistant,
prompt: {
default: DEFAULT_CHAT_TITLE,
},
},
{
promptId: promptDictionary.chatTitle,
promptGroupId: promptGroupId.aiAssistant,
provider: 'bedrock',
prompt: {
default: BEDROCK_CHAT_TITLE,
},
},
{
promptId: promptDictionary.chatTitle,
promptGroupId: promptGroupId.aiAssistant,
provider: 'gemini',
prompt: {
default: GEMINI_CHAT_TITLE,
},
},
];
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,21 @@ export const ATTACK_DISCOVERY_GENERATION_SUMMARY_MARKDOWN = `A markdown summary
// Constraint for the generated attack-discovery title: short, plain text.
export const ATTACK_DISCOVERY_GENERATION_TITLE =
  'A short, no more than 7 words, title for the insight, NOT formatted with special syntax or markdown. This must be as brief as possible.';
// Markdown insights description; SYNTAX / *_SYNTAX_EXAMPLES are defined elsewhere in this file.
export const ATTACK_DISCOVERY_GENERATION_INSIGHTS = `Insights with markdown that always uses special ${SYNTAX} syntax for field names and values from the source data. ${GOOD_SYNTAX_EXAMPLES} ${BAD_SYNTAX_EXAMPLES}`;

// Chat-title system prompt for Bedrock models: forbids quoting the title and
// asks for the title text only, with no explanation.
export const BEDROCK_CHAT_TITLE = `You are a helpful assistant for Elastic Security. Assume the following user message is the start of a conversation between you and a user; give this conversation a title based on the content below. DO NOT UNDER ANY CIRCUMSTANCES wrap this title in single or double quotes. This title is shown in a list of conversations to the user, so title it for the user, not for you. Respond with the title only with no other text explaining your response. As an example, for the given MESSAGE, this is the TITLE:
MESSAGE: I am having trouble with the Elastic Security app.
TITLE: Troubleshooting Elastic Security app issues
`;

// Chat-title system prompt for Gemini models: plain text, no quotes or
// markdown, title only.
export const GEMINI_CHAT_TITLE = `You are a title generator for a helpful assistant for Elastic Security. Assume the following human message is the start of a conversation between you and a human. Generate a relevant conversation title for the human's message in plain text. Make sure the title is formatted for the user, without using quotes or markdown. The title should clearly reflect the content of the message and be appropriate for a list of conversations. Respond only with the title. As an example, for the given MESSAGE, this is the TITLE:
MESSAGE: I am having trouble with the Elastic Security app.
TITLE: Troubleshooting Elastic Security app issues
`;

// Default chat-title system prompt used when no provider-specific variant applies.
export const DEFAULT_CHAT_TITLE = `You are a helpful assistant for Elastic Security. Assume the following user message is the start of a conversation between you and a user; give this conversation a title based on the content below. DO NOT UNDER ANY CIRCUMSTANCES wrap this title in single or double quotes. This title is shown in a list of conversations to the user, so title it for the user, not for you. As an example, for the given MESSAGE, this is the TITLE:
MESSAGE: I am having trouble with the Elastic Security app.
TITLE: Troubleshooting Elastic Security app issues
`;
Loading

0 comments on commit 092c2cd

Please sign in to comment.