Skip to content

Commit

Permalink
Magic
Browse files Browse the repository at this point in the history
  • Loading branch information
dgieselaar committed Sep 17, 2024
1 parent 029eb9f commit 9178f90
Show file tree
Hide file tree
Showing 13 changed files with 63 additions and 14 deletions.
4 changes: 2 additions & 2 deletions x-pack/plugins/inference/common/output/create_output_api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ export function createOutputApi(chatCompleteApi: ChatCompleteAPI): OutputAPI {
...(schema
? {
tools: {
output: {
structuredOutput: {
description: `Use the following schema to respond to the user's request in structured data, so it can be parsed and handled.`,
schema,
},
},
toolChoice: { function: 'output' as const },
toolChoice: { function: 'structuredOutput' as const },
}
: {}),
}).pipe(
Expand Down
5 changes: 5 additions & 0 deletions x-pack/plugins/inference/server/tasks/nl_to_esql/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import { ToolSchema, generateFakeToolCallId, isChatCompletionMessageEvent } from
import {
ChatCompletionChunkEvent,
ChatCompletionMessageEvent,
FunctionCallingMode,
Message,
MessageRole,
} from '../../../common/chat_complete';
Expand All @@ -38,11 +39,13 @@ export function naturalLanguageToEsql<TToolOptions extends ToolOptions>({
tools,
toolChoice,
logger,
functionCalling,
...rest
}: {
client: Pick<InferenceClient, 'output' | 'chatComplete'>;
connectorId: string;
logger: Pick<Logger, 'debug'>;
functionCalling?: FunctionCallingMode;
} & TToolOptions &
({ input: string } | { messages: Message[] })): Observable<NlToEsqlTaskEvent<TToolOptions>> {
const hasTools = !isEmpty(tools) && toolChoice !== ToolChoiceType.none;
Expand Down Expand Up @@ -130,6 +133,7 @@ export function naturalLanguageToEsql<TToolOptions extends ToolOptions>({
}),
client
.chatComplete({
functionCalling,
connectorId,
system: `${systemMessage}
Expand Down Expand Up @@ -233,6 +237,7 @@ export function naturalLanguageToEsql<TToolOptions extends ToolOptions>({

return client
.output('request_documentation', {
functionCalling,
connectorId,
system: systemMessage,
previousMessages: messages,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ describe('chatFunctionClient', () => {
messages: [],
signal: new AbortController().signal,
connectorId: 'foo',
useSimulatedFunctionCalling: false,
});
}).rejects.toThrowError(`Function arguments are invalid`);

Expand Down Expand Up @@ -109,6 +110,7 @@ describe('chatFunctionClient', () => {
messages: [],
signal: new AbortController().signal,
connectorId: 'foo',
useSimulatedFunctionCalling: false,
});

expect(result).toEqual({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,15 @@ export class ChatFunctionClient {
messages,
signal,
connectorId,
useSimulatedFunctionCalling,
}: {
chat: FunctionCallChatFunction;
name: string;
args: string | undefined;
messages: Message[];
signal: AbortSignal;
connectorId: string;
useSimulatedFunctionCalling: boolean;
}): Promise<FunctionResponse> {
const fn = this.functionRegistry.get(name);

Expand All @@ -172,6 +174,7 @@ export class ChatFunctionClient {
screenContexts: this.screenContexts,
chat,
connectorId,
useSimulatedFunctionCalling,
},
signal
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ export class ObservabilityAIAssistantClient {
disableFunctions,
tracer: completeTracer,
connectorId,
useSimulatedFunctionCalling: simulateFunctionCalling !== false,
})
);
}),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ function executeFunctionAndCatchError({
logger,
tracer,
connectorId,
useSimulatedFunctionCalling,
}: {
name: string;
args: string | undefined;
Expand All @@ -64,6 +65,7 @@ function executeFunctionAndCatchError({
logger: Logger;
tracer: LangTracer;
connectorId: string;
useSimulatedFunctionCalling: boolean;
}): Observable<MessageOrChatEvent> {
// hide token count events from functions to prevent them from
// having to deal with it as well
Expand All @@ -84,6 +86,7 @@ function executeFunctionAndCatchError({
signal,
messages,
connectorId,
useSimulatedFunctionCalling,
})
);

Expand Down Expand Up @@ -181,6 +184,7 @@ export function continueConversation({
disableFunctions,
tracer,
connectorId,
useSimulatedFunctionCalling,
}: {
messages: Message[];
functionClient: ChatFunctionClient;
Expand All @@ -197,6 +201,7 @@ export function continueConversation({
};
tracer: LangTracer;
connectorId: string;
useSimulatedFunctionCalling: boolean;
}): Observable<MessageOrChatEvent> {
let nextFunctionCallsLeft = functionCallsLeft;

Expand Down Expand Up @@ -310,6 +315,7 @@ export function continueConversation({
logger,
tracer,
connectorId,
useSimulatedFunctionCalling,
});
}

Expand Down Expand Up @@ -338,6 +344,7 @@ export function continueConversation({
disableFunctions,
tracer,
connectorId,
useSimulatedFunctionCalling,
});
})
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ type RespondFunction<TArguments, TResponse extends FunctionResponse> = (
screenContexts: ObservabilityAIAssistantScreenContextRequest[];
chat: FunctionCallChatFunction;
connectorId: string;
useSimulatedFunctionCalling: boolean;
},
signal: AbortSignal
) => Promise<TResponse>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ export function convertMessagesForInference(messages: Message[]): InferenceMessa
inferenceMessages.push({
role: InferenceMessageRole.Assistant,
content: message.message.content ?? null,
...(message.message.function_call
...(message.message.function_call?.name
? {
toolCalls: [
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,16 @@ export interface VisualizeQueryResponsev1 {
};
}

export type VisualizeQueryResponse = VisualizeQueryResponsev0 | VisualizeQueryResponsev1;
export type VisualizeQueryResponsev2 = VisualizeQueryResponsev1 & {
data: {
correctedQuery: string;
};
};

export type VisualizeQueryResponse =
| VisualizeQueryResponsev0
| VisualizeQueryResponsev1
| VisualizeQueryResponsev2;

export type VisualizeESQLFunctionArguments = FromSchema<
(typeof visualizeESQLFunction)['parameters']
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,11 @@ export function registerVisualizeQueryRenderFunction({
? typedResponse.content.errorMessages
: [];

const correctedQuery =
'data' in typedResponse && 'correctedQuery' in typedResponse.data
? typedResponse.data.correctedQuery
: query;

if ('data' in typedResponse && 'userOverrides' in typedResponse.data) {
userOverrides = typedResponse.data.userOverrides;
}
Expand Down Expand Up @@ -472,7 +477,7 @@ export function registerVisualizeQueryRenderFunction({
break;
}

const trimmedQuery = query.trim();
const trimmedQuery = correctedQuery.trim();

return (
<VisualizeESQL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
* 2.0.
*/

import { isChatCompletionChunkEvent, isOutputEvent } from '@kbn/inference-plugin/common';
import {
correctCommonEsqlMistakes,
isChatCompletionChunkEvent,
isOutputEvent,
} from '@kbn/inference-plugin/common';
import { naturalLanguageToEsql } from '@kbn/inference-plugin/server';
import {
FunctionVisibility,
Expand Down Expand Up @@ -74,9 +78,11 @@ export function registerQueryFunction({
} as const,
},
async ({ arguments: { query } }) => {
const correctedQuery = correctCommonEsqlMistakes(query).output;

const client = (await resources.context.core).elasticsearch.client.asCurrentUser;
const { error, errorMessages, rows, columns } = await runAndValidateEsqlQuery({
query,
query: correctedQuery,
client,
});

Expand Down Expand Up @@ -108,7 +114,7 @@ export function registerQueryFunction({
function takes no input.`,
visibility: FunctionVisibility.AssistantOnly,
},
async ({ messages, connectorId }, signal) => {
async ({ messages, connectorId, useSimulatedFunctionCalling }, signal) => {
const esqlFunctions = functions
.getFunctions()
.filter(
Expand All @@ -132,6 +138,7 @@ export function registerQueryFunction({
.concat(esqlFunctions)
.map((fn) => [fn.name, { description: fn.description, schema: fn.parameters }])
),
functionCalling: useSimulatedFunctionCalling ? 'simulated' : 'native',
});

const chatMessageId = v4();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,20 @@ export async function runAndValidateEsqlQuery({
error?: Error;
errorMessages?: string[];
}> {
const { errors } = await validateQuery(query, getAstAndSyntaxErrors, {
const queryWithoutLineBreaks = query.replaceAll(/\n/g, '');

const { errors } = await validateQuery(queryWithoutLineBreaks, getAstAndSyntaxErrors, {
// setting this to true, we don't want to validate the index / fields existence
ignoreOnMissingCallbacks: true,
});

const asCommands = splitIntoCommands(query);
const asCommands = splitIntoCommands(queryWithoutLineBreaks);

const errorMessages = errors?.map((error) => {
if ('location' in error) {
const commandsUntilEndOfError = splitIntoCommands(query.substring(0, error.location.max));
const commandsUntilEndOfError = splitIntoCommands(
queryWithoutLineBreaks.substring(0, error.location.max)
);
const lastCompleteCommand = asCommands[commandsUntilEndOfError.length - 1];
if (lastCompleteCommand) {
return `Error in ${lastCompleteCommand.command}\n: ${error.text}`;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
* 2.0.
*/
import { VisualizeESQLUserIntention } from '@kbn/observability-ai-assistant-plugin/common/functions/visualize_esql';
import { correctCommonEsqlMistakes } from '@kbn/inference-plugin/common';
import {
visualizeESQLFunction,
type VisualizeQueryResponsev1,
VisualizeQueryResponsev2,
} from '../../common/functions/visualize_esql';
import type { FunctionRegistrationParameters } from '.';
import { runAndValidateEsqlQuery } from './query/validate_esql_query';
Expand All @@ -32,12 +33,15 @@ export function registerVisualizeESQLFunction({
}: FunctionRegistrationParameters) {
functions.registerFunction(
visualizeESQLFunction,
async ({ arguments: { query, intention } }): Promise<VisualizeQueryResponsev1> => {
async ({ arguments: { query, intention } }): Promise<VisualizeQueryResponsev2> => {
// errorMessages contains the syntax errors from the client side valdation
// error contains the error from the server side validation, it is always one error
// and help us identify errors like index not found, field not found etc.

const correctedQuery = correctCommonEsqlMistakes(query).output;

const { columns, errorMessages, rows, error } = await runAndValidateEsqlQuery({
query,
query: correctedQuery,
client: (await resources.context.core).elasticsearch.client.asCurrentUser,
});

Expand All @@ -47,6 +51,7 @@ export function registerVisualizeESQLFunction({
data: {
columns: columns ?? [],
rows: rows ?? [],
correctedQuery,
},
content: {
message,
Expand Down

0 comments on commit 9178f90

Please sign in to comment.