diff --git a/package.json b/package.json index d9506aac45273..f149ff4fac1f8 100644 --- a/package.json +++ b/package.json @@ -879,6 +879,7 @@ "js-search": "^1.4.3", "js-sha256": "^0.9.0", "js-yaml": "^3.14.1", + "json-schema-to-ts": "^2.9.1", "json-stable-stringify": "^1.0.1", "json-stringify-pretty-compact": "1.2.0", "json-stringify-safe": "5.0.1", diff --git a/x-pack/plugins/actions/server/sub_action_framework/sub_action_connector.ts b/x-pack/plugins/actions/server/sub_action_framework/sub_action_connector.ts index a795b0eedc2ac..11419cb4613ed 100644 --- a/x-pack/plugins/actions/server/sub_action_framework/sub_action_connector.ts +++ b/x-pack/plugins/actions/server/sub_action_framework/sub_action_connector.ts @@ -11,6 +11,8 @@ import { Logger } from '@kbn/logging'; import axios, { AxiosInstance, AxiosResponse, AxiosError, AxiosRequestHeaders } from 'axios'; import { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server'; import { ElasticsearchClient } from '@kbn/core-elasticsearch-server'; +import { finished } from 'stream/promises'; +import { IncomingMessage } from 'http'; import { assertURL } from './helpers/validators'; import { ActionsConfigurationUtilities } from '../actions_config'; import { SubAction, SubActionRequestParams } from './types'; @@ -140,6 +142,19 @@ export abstract class SubActionConnector { `Request to external service failed. Connector Id: ${this.connector.id}. Connector type: ${this.connector.type}. Method: ${error.config.method}. URL: ${error.config.url}` ); + let responseBody = ''; + + // The error response body may also be a stream, e.g. for the GenAI connector + if (error.response?.config.responseType === 'stream' && error.response?.data) { + const incomingMessage = error.response.data as IncomingMessage; + + incomingMessage.on('data', (chunk) => { + responseBody += chunk.toString(); + }); + await finished(incomingMessage); + error.response.data = JSON.parse(responseBody); + } + const errorMessage = `Status code: ${ error.status ?? error.response?.status }. Message: ${this.getResponseErrorMessage(error)}`; diff --git a/x-pack/plugins/observability_ai_assistant/common/types.ts b/x-pack/plugins/observability_ai_assistant/common/types.ts index 134bdc7b606e1..78b90d8bd551f 100644 --- a/x-pack/plugins/observability_ai_assistant/common/types.ts +++ b/x-pack/plugins/observability_ai_assistant/common/types.ts @@ -5,7 +5,10 @@ * 2.0. */ -import { Serializable } from '@kbn/utility-types'; +import type { Serializable } from '@kbn/utility-types'; +import type { FromSchema } from 'json-schema-to-ts'; +import type { JSONSchema } from 'json-schema-to-ts'; +import React from 'react'; export enum MessageRole { System = 'system', @@ -24,10 +27,10 @@ export interface Message { role: MessageRole; function_call?: { name: string; - args?: Serializable; + arguments?: string; trigger: MessageRole.Assistant | MessageRole.User | MessageRole.Elastic; }; - data?: Serializable; + data?: string; }; } @@ -54,3 +57,51 @@ export type ConversationRequestBase = Omit; + +export interface ContextDefinition { + name: string; + description: string; +} + +interface FunctionResponse { + content?: Serializable; + data?: Serializable; +} + +interface FunctionOptions { + name: string; + description: string; + parameters: TParameters; + contexts: string[]; +} + +type RespondFunction< + TParameters extends CompatibleJSONSchema, + TResponse extends FunctionResponse +> = (options: { arguments: FromSchema }, signal: AbortSignal) => Promise; + +type RenderFunction = (options: { + response: TResponse; +}) => React.ReactNode; + +export interface FunctionDefinition { + options: FunctionOptions; + respond: (options: { arguments: any }, signal: AbortSignal) => Promise; + render?: RenderFunction; +} + +export type RegisterContextDefinition = (options: ContextDefinition) => void; + +export type RegisterFunctionDefinition = < + TParameters extends CompatibleJSONSchema, + TResponse extends FunctionResponse +>( + options: FunctionOptions, + respond: RespondFunction, + render?: RenderFunction +) => void; + +export type ContextRegistry = Map; +export type FunctionRegistry = Map; diff --git a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.stories.tsx b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.stories.tsx index 3a861b9222671..39309e67299ac 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.stories.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.stories.tsx @@ -7,6 +7,8 @@ import { ComponentStory } from '@storybook/react'; import React from 'react'; +import { Observable } from 'rxjs'; +import { ObservabilityAIAssistantService } from '../../types'; import { ChatBody as Component } from './chat_body'; export default { @@ -54,13 +56,11 @@ const defaultProps: ChatBodyProps = { currentUser: { username: 'elastic', }, - chat: { - loading: false, - abort: () => {}, - generate: async () => { - return {} as any; + service: { + chat: () => { + return new Observable(); }, - }, + } as unknown as ObservabilityAIAssistantService, }; export const ChatBody = Template.bind({}); diff --git a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.tsx b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.tsx index 3e9c7fc56a1df..1bac96a9ddf60 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.tsx @@ -10,9 +10,9 @@ import { css } from '@emotion/css'; import type { AuthenticatedUser } from '@kbn/security-plugin/common'; import React from 'react'; import { type ConversationCreateRequest } from '../../../common/types'; -import type { UseChatResult } from '../../hooks/use_chat'; import type { UseGenAIConnectorsResult } from '../../hooks/use_genai_connectors'; import { useTimeline } from '../../hooks/use_timeline'; +import { ObservabilityAIAssistantService } from '../../types'; import { HideExpandConversationListButton } from '../buttons/hide_expand_conversation_list_button'; import { ChatHeader } from './chat_header'; import { ChatPromptEditor } from './chat_prompt_editor'; @@ -30,14 +30,14 @@ export function ChatBody({ initialConversation, connectors, currentUser, - chat, + service, isConversationListExpanded, onToggleExpandConversationList, }: { initialConversation?: ConversationCreateRequest; connectors: UseGenAIConnectorsResult; currentUser?: Pick; - chat: UseChatResult; + service: ObservabilityAIAssistantService; isConversationListExpanded?: boolean; onToggleExpandConversationList?: () => void; }) { @@ -47,7 +47,7 @@ export function ChatBody({ initialConversation, connectors, currentUser, - chat, + service, }); return ( @@ -93,7 +93,7 @@ export function ChatBody({ diff --git a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_flyout.tsx b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_flyout.tsx index 00fac0dc2bdf0..abd395bc1e03a 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_flyout.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_flyout.tsx @@ -7,9 +7,9 @@ import React, { useState } from 'react'; import { EuiFlexGroup, EuiFlexItem, EuiFlyout, useEuiTheme } from '@elastic/eui'; import type { ConversationCreateRequest } from '../../../common/types'; -import { useChat } from '../../hooks/use_chat'; import { useCurrentUser } from '../../hooks/use_current_user'; import { useGenAIConnectors } from '../../hooks/use_genai_connectors'; +import { useObservabilityAIAssistant } from '../../hooks/use_observability_ai_assistant'; import { ChatBody } from './chat_body'; import { ConversationList } from './conversation_list'; @@ -22,11 +22,11 @@ export function ChatFlyout({ isOpen: boolean; onClose: () => void; }) { - const currentUser = useCurrentUser(); const connectors = useGenAIConnectors(); - const { euiTheme } = useEuiTheme(); - const chat = useChat(); + const currentUser = useCurrentUser(); + + const { euiTheme } = useEuiTheme(); const [isConversationListExpanded, setIsConversationListExpanded] = useState(false); @@ -34,6 +34,8 @@ export function ChatFlyout({ const handleClickNewChat = () => {}; const handleClickSettings = () => {}; + const service = useObservabilityAIAssistant(); + return isOpen ? ( @@ -51,7 +53,7 @@ export function ChatFlyout({ ) : null} void; } +const euiCommentClassName = css` + .euiCommentEvent__headerEvent { + flex-grow: 1; + } + + > div:last-child { + overflow: hidden; + } +`; + export function ChatItem({ title, content, @@ -143,15 +154,16 @@ export function ChatItem({ title={title} /> } - timelineAvatar={} + className={euiCommentClassName} + timelineAvatar={ + + } username={getRoleTranslation(role)} > - {content !== undefined || error || loading ? ( + {content || error || loading || controls ? ( - ) : null + content || loading ? : null } error={error} controls={controls} diff --git a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_item_avatar.tsx b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_item_avatar.tsx index 6afa522d426d2..d04f818bb204c 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_item_avatar.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_item_avatar.tsx @@ -15,16 +15,19 @@ import { MessageRole } from '../../../common/types'; interface ChatAvatarProps { currentUser?: Pick | undefined; role: MessageRole; + loading: boolean; } -export function ChatItemAvatar({ currentUser, role }: ChatAvatarProps) { +export function ChatItemAvatar({ currentUser, role, loading }: ChatAvatarProps) { + const isLoading = loading || !currentUser; + + if (isLoading) { + return ; + } + switch (role) { case MessageRole.User: - return currentUser ? ( - - ) : ( - - ); + return ; case MessageRole.Assistant: case MessageRole.Elastic: diff --git a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_prompt_editor.tsx b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_prompt_editor.tsx index 6e06590ed3621..a23c06561fffc 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_prompt_editor.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_prompt_editor.tsx @@ -16,15 +16,13 @@ import { EuiPopover, } from '@elastic/eui'; import { i18n } from '@kbn/i18n'; -import { useFunctions, Func } from '../../hooks/use_functions'; +import { useFunctions, type Func } from '../../hooks/use_functions'; +import { type Message, MessageRole } from '../../../common'; export interface ChatPromptEditorProps { disabled: boolean; loading: boolean; - onSubmit: (message: { - content?: string; - function_call?: { name: string; args?: string }; - }) => Promise; + onSubmit: (message: Message) => Promise; } export function ChatPromptEditor({ onSubmit, disabled, loading }: ChatPromptEditorProps) { @@ -40,7 +38,10 @@ export function ChatPromptEditor({ onSubmit, disabled, loading }: ChatPromptEdit const handleSubmit = () => { const currentPrompt = prompt; setPrompt(''); - onSubmit({ content: currentPrompt }) + onSubmit({ + '@timestamp': new Date().toISOString(), + message: { role: MessageRole.User, content: currentPrompt }, + }) .then(() => { setPrompt(''); }) diff --git a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_timeline.tsx b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_timeline.tsx index db671271176fe..1f4b4cf3ec091 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_timeline.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_timeline.tsx @@ -5,23 +5,17 @@ * 2.0. */ -import React from 'react'; import { EuiCommentList } from '@elastic/eui'; import type { AuthenticatedUser } from '@kbn/security-plugin/common'; -import { MessageRole } from '../../../common/types'; +import React from 'react'; +import type { Message } from '../../../common'; import type { Feedback } from '../feedback_buttons'; import { ChatItem } from './chat_item'; -export interface ChatTimelineItem { +export interface ChatTimelineItem + extends Pick { id: string; title: string; - role: MessageRole; - content?: string; - function_call?: { - name: string; - args?: string; - trigger?: MessageRole; - }; loading: boolean; error?: any; canEdit: boolean; diff --git a/x-pack/plugins/observability_ai_assistant/public/components/insight/insight.tsx b/x-pack/plugins/observability_ai_assistant/public/components/insight/insight.tsx index 5c4e3345b07d7..b9fab3266cde2 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/insight/insight.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/insight/insight.tsx @@ -4,43 +4,63 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -import React, { useEffect, useMemo, useState } from 'react'; -import { useKibana } from '@kbn/kibana-react-plugin/public'; import { EuiFlexGroup, EuiFlexItem } from '@elastic/eui'; -import { type ConversationCreateRequest, type Message, MessageRole } from '../../../common/types'; -import { useChat } from '../../hooks/use_chat'; +import { useKibana } from '@kbn/kibana-react-plugin/public'; +import { AbortError } from '@kbn/kibana-utils-plugin/common'; +import React, { useCallback, useEffect, useMemo, useState } from 'react'; +import type { Subscription } from 'rxjs'; +import { MessageRole, type Message } from '../../../common/types'; import { useGenAIConnectors } from '../../hooks/use_genai_connectors'; +import { useObservabilityAIAssistant } from '../../hooks/use_observability_ai_assistant'; +import type { PendingMessage } from '../../types'; +import { ChatFlyout } from '../chat/chat_flyout'; import { ConnectorSelectorBase } from '../connector_selector/connector_selector_base'; import { MessagePanel } from '../message_panel/message_panel'; import { MessageText } from '../message_panel/message_text'; -import { InsightBase } from './insight_base'; -import { InsightMissingCredentials } from './insight_missing_credentials'; -import { StopGeneratingButton } from '../buttons/stop_generating_button'; import { RegenerateResponseButton } from '../buttons/regenerate_response_button'; import { StartChatButton } from '../buttons/start_chat_button'; -import { ChatFlyout } from '../chat/chat_flyout'; +import { StopGeneratingButton } from '../buttons/stop_generating_button'; +import { InsightBase } from './insight_base'; +import { InsightMissingCredentials } from './insight_missing_credentials'; function ChatContent({ messages, connectorId }: { messages: Message[]; connectorId: string }) { - const chat = useChat(); + const service = useObservabilityAIAssistant(); - const { generate } = chat; + const [pendingMessage, setPendingMessage] = useState(); + const [loading, setLoading] = useState(false); + const [subscription, setSubscription] = useState(); - useEffect(() => { - generate({ messages, connectorId }).catch(() => { - // error is handled in chat, and we don't do anything with the full response for now. + const reloadReply = useCallback(() => { + setLoading(true); + + const nextSubscription = service.chat({ messages, connectorId }).subscribe({ + next: (msg) => { + setPendingMessage(() => msg); + }, + complete: () => { + setLoading(false); + }, }); - }, [generate, messages, connectorId]); - const initialConversation = useMemo(() => { + setSubscription(nextSubscription); + }, [messages, connectorId, service]); + + useEffect(() => { + reloadReply(); + }, [reloadReply]); + + const [isOpen, setIsOpen] = useState(false); + + const initialConversation = useMemo(() => { const time = new Date().toISOString(); return { '@timestamp': time, - messages: chat.content + messages: pendingMessage?.message.content ? messages.concat({ '@timestamp': time, message: { role: MessageRole.Assistant, - content: chat.content, + content: pendingMessage.message.content, }, }) : messages, @@ -50,20 +70,27 @@ function ChatContent({ messages, connectorId }: { messages: Message[]; connector labels: {}, numeric_labels: {}, }; - }, [messages, chat.content]); - - const [isOpen, setIsOpen] = useState(false); + }, [pendingMessage, messages]); return ( <> } - error={chat.error} + body={} + error={pendingMessage?.error} controls={ - chat.loading ? ( + loading ? ( { - chat.abort(); + subscription?.unsubscribe(); + setLoading(false); + setPendingMessage((prev) => ({ + message: { + role: MessageRole.Assistant, + ...prev?.message, + }, + aborted: true, + error: new AbortError(), + })); }} /> ) : ( @@ -71,7 +98,7 @@ function ChatContent({ messages, connectorId }: { messages: Message[]; connector { - generate({ messages, connectorId }); + reloadReply(); }} /> diff --git a/x-pack/plugins/observability_ai_assistant/public/components/message_panel/message_text.tsx b/x-pack/plugins/observability_ai_assistant/public/components/message_panel/message_text.tsx index aa73cce89e139..76be222d6fed2 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/message_panel/message_text.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/message_panel/message_text.tsx @@ -17,6 +17,10 @@ interface Props { loading: boolean; } +const containerClassName = css` + overflow-wrap: break-word; +`; + const cursorCss = css` @keyframes blink { 0% { @@ -81,7 +85,7 @@ const loadingCursorPlugin = () => { export function MessageText(props: Props) { return ( - + ; -const mockUseObservabilityAIAssistant = useObservabilityAIAssistant as jest.MockedFunction< - typeof useObservabilityAIAssistant ->; - -const mockChat = jest.fn(); - -mockUseObservabilityAIAssistant.mockImplementation( - () => - ({ - chat: mockChat, - } as unknown as ObservabilityAIAssistantService) -); - -function mockDeltas(deltas: Array>) { - mockResponse( - Promise.resolve( - new Observable((subscriber) => { - async function simulateDelays() { - for (const delta of deltas) { - await new Promise((resolve) => { - setTimeout(() => { - subscriber.next({ - choices: [ - { - role: 'assistant', - delta, - }, - ], - }); - resolve(); - }, 100); - }); - } - subscriber.complete(); - } - - simulateDelays(); - }) - ) - ); -} - -function mockResponse(response: Promise) { - mockChat.mockReturnValueOnce(response); -} - -describe('useChat', () => { - beforeEach(() => { - mockUseKibana.mockReturnValue({ - services: { notifications: { showErrorDialog: jest.fn() } }, - } as any); - }); - - it('returns the result of the chat API', async () => { - mockDeltas([{ content: 'testContent' }]); - const { result, waitFor } = renderHook(() => useChat()); - - act(() => { - result.current.generate({ messages: [], connectorId: 'myConnectorId' }); - }); - - expect(result.current.loading).toBeTruthy(); - expect(result.current.error).toBeUndefined(); - expect(result.current.content).toBeUndefined(); - - await waitFor(() => result.current.loading === false, WAIT_OPTIONS); - - expect(result.current.error).toBeUndefined(); - expect(result.current.content).toBe('testContent'); - }); - - it('handles 4xx and 5xx', async () => { - mockResponse(Promise.reject(new Error())); - const { result, waitFor } = renderHook(() => useChat()); - - const catchMock = jest.fn(); - - act(() => { - result.current.generate({ messages: [], connectorId: 'myConnectorId' }).catch(catchMock); - }); - - await waitFor(() => result.current.loading === false, WAIT_OPTIONS); - - expect(catchMock).toHaveBeenCalled(); - - expect(result.current.error).toBeInstanceOf(Error); - expect(result.current.content).toBeUndefined(); - - expect(mockUseKibana().services.notifications?.showErrorDialog).toHaveBeenCalled(); - }); - - it('handles valid responses but generation errors', async () => { - mockResponse( - Promise.resolve( - new Observable((subscriber) => { - subscriber.next({ choices: [{ role: 'assistant', delta: { content: 'foo' } }] }); - setTimeout(() => { - subscriber.error(new Error()); - }, 100); - }) - ) - ); - - const { result, waitFor } = renderHook(() => useChat()); - - act(() => { - result.current.generate({ messages: [], connectorId: 'myConnectorId' }).catch(() => {}); - }); - - await waitFor(() => result.current.loading === false, WAIT_OPTIONS); - - expect(result.current.loading).toBe(false); - expect(result.current.error).toBeInstanceOf(Error); - expect(result.current.content).toBe('foo'); - - expect(mockUseKibana().services.notifications?.showErrorDialog).toHaveBeenCalled(); - }); - - it('handles aborted requests', async () => { - mockResponse( - Promise.resolve( - new Observable((subscriber) => { - subscriber.next({ choices: [{ role: 'assistant', delta: { content: 'foo' } }] }); - }) - ) - ); - - const { result, waitFor, unmount } = renderHook(() => useChat()); - - act(() => { - result.current.generate({ messages: [], connectorId: 'myConnectorId' }); - }); - - await waitFor(() => result.current.content === 'foo', WAIT_OPTIONS); - - unmount(); - - expect(mockUseKibana().services.notifications?.showErrorDialog).not.toHaveBeenCalled(); - }); - - it('handles regenerations triggered by updates', async () => { - mockResponse( - Promise.resolve( - new Observable((subscriber) => { - subscriber.next({ choices: [{ role: 'assistant', delta: { content: 'foo' } }] }); - }) - ) - ); - - const { result, waitFor } = renderHook(() => useChat()); - - act(() => { - result.current.generate({ messages: [], connectorId: 'myConnectorId' }); - }); - - await waitFor(() => result.current.content === 'foo', WAIT_OPTIONS); - - mockDeltas([{ content: 'bar' }]); - - act(() => { - result.current.generate({ messages: [], connectorId: 'myConnectorId' }); - }); - - await waitFor(() => result.current.loading === false, WAIT_OPTIONS); - - expect(mockUseKibana().services.notifications?.showErrorDialog).not.toHaveBeenCalled(); - - expect(result.current.content).toBe('bar'); - }); - - it('handles streaming updates', async () => { - mockDeltas([ - { - content: 'my', - }, - { - content: ' ', - }, - { - content: 'update', - }, - ]); - - const { result, waitForNextUpdate } = renderHook(() => useChat()); - - act(() => { - result.current.generate({ messages: [], connectorId: 'myConnectorId' }); - }); - - await waitForNextUpdate(WAIT_OPTIONS); - - expect(result.current.content).toBe('my'); - - await waitForNextUpdate(WAIT_OPTIONS); - - expect(result.current.content).toBe('my '); - - await waitForNextUpdate(WAIT_OPTIONS); - - expect(result.current.content).toBe('my update'); - }); - - it('handles user aborts', async () => { - const thenMock = jest.fn(); - const catchMock = jest.fn(); - - mockResponse( - Promise.resolve( - new Observable((subscriber) => { - subscriber.next({ choices: [{ role: 'assistant', delta: { content: 'foo' } }] }); - }) - ) - ); - - const { result, waitForNextUpdate, waitFor } = renderHook(() => useChat()); - - act(() => { - result.current - .generate({ messages: [], connectorId: 'myConnectorId' }) - .then(thenMock, catchMock); - }); - - await waitForNextUpdate(WAIT_OPTIONS); - - act(() => { - result.current.abort(); - }); - - await waitFor(() => thenMock.mock.calls.length > 0); - - expect(mockUseKibana().services.notifications?.showErrorDialog).not.toHaveBeenCalled(); - - expect(result.current.content).toBe('foo'); - expect(result.current.loading).toBe(false); - expect(result.current.error).toBeInstanceOf(AbortError); - - expect(thenMock).toHaveBeenCalledWith({ - aborted: true, - content: 'foo', - function_call: { - args: '', - name: '', - }, - }); - - expect(catchMock).not.toHaveBeenCalled(); - }); - - it('handles user regenerations', async () => { - mockResponse( - Promise.resolve( - new Observable((subscriber) => { - subscriber.next({ choices: [{ role: 'assistant', delta: { content: 'foo' } }] }); - }) - ) - ); - - const { result, waitForNextUpdate } = renderHook(() => useChat()); - - act(() => { - result.current.generate({ messages: [], connectorId: 'myConnectorId' }); - }); - - await waitForNextUpdate(WAIT_OPTIONS); - - act(() => { - mockDeltas([{ content: 'bar' }]); - result.current.generate({ messages: [], connectorId: 'mySecondConnectorId' }); - }); - - await waitForNextUpdate(WAIT_OPTIONS); - - expect(mockUseKibana().services.notifications?.showErrorDialog).not.toHaveBeenCalled(); - - expect(result.current.content).toBe('bar'); - expect(result.current.loading).toBe(false); - expect(result.current.error).toBeUndefined(); - }); -}); diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.ts b/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.ts deleted file mode 100644 index 32232c011d989..0000000000000 --- a/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.ts +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -import { i18n } from '@kbn/i18n'; -import { useKibana } from '@kbn/kibana-react-plugin/public'; -import { AbortError } from '@kbn/kibana-utils-plugin/common'; -import { clone } from 'lodash'; -import { useCallback, useEffect, useRef, useState } from 'react'; -import { concatMap, delay, of } from 'rxjs'; -import type { Message } from '../../common/types'; -import { useObservabilityAIAssistant } from './use_observability_ai_assistant'; - -interface MessageResponse { - content?: string; - function_call?: { - name: string; - args?: string; - }; -} - -export interface UseChatResult { - content?: string; - function_call?: { - name: string; - args?: string; - }; - loading: boolean; - error?: Error; - abort: () => void; - generate: (options: { messages: Message[]; connectorId: string }) => Promise<{ - content?: string; - function_call?: { name: string; args?: string }; - aborted?: boolean; - }>; -} - -export function useChat(): UseChatResult { - const assistant = useObservabilityAIAssistant(); - - const { - services: { notifications }, - } = useKibana(); - - const [response, setResponse] = useState(undefined); - - const [error, setError] = useState(undefined); - - const [loading, setLoading] = useState(false); - - const controllerRef = useRef(new AbortController()); - - const generate = useCallback( - ({ messages, connectorId }: { messages: Message[]; connectorId: string }) => { - controllerRef.current.abort(); - - const controller = (controllerRef.current = new AbortController()); - - setResponse(undefined); - setError(undefined); - setLoading(true); - - const partialResponse = { - content: '', - function_call: { - name: '', - args: '', - }, - }; - - return assistant - .chat({ messages, connectorId, signal: controller.signal }) - .then((response$) => { - return new Promise((resolve, reject) => { - const subscription = response$ - .pipe(concatMap((value) => of(value).pipe(delay(50)))) - .subscribe({ - next: (chunk) => { - if (controller.signal.aborted) { - return; - } - partialResponse.content += chunk.choices[0].delta.content ?? ''; - partialResponse.function_call.name += - chunk.choices[0].delta.function_call?.name ?? ''; - partialResponse.function_call.args += - chunk.choices[0].delta.function_call?.args ?? ''; - setResponse(clone(partialResponse)); - }, - error: (err) => { - reject(err); - }, - complete: () => { - resolve(); - }, - }); - - controller.signal.addEventListener('abort', () => { - subscription.unsubscribe(); - reject(new AbortError()); - }); - }); - }) - .then(() => { - return Promise.resolve(partialResponse); - }) - .catch((err) => { - if (controller.signal.aborted) { - return Promise.resolve({ - ...partialResponse, - aborted: true, - }); - } - notifications?.showErrorDialog({ - title: i18n.translate('xpack.observabilityAiAssistant.failedToLoadChatTitle', { - defaultMessage: 'Failed to load chat', - }), - error: err, - }); - setError(err); - throw err; - }) - .finally(() => { - if (controller.signal.aborted) { - return; - } - setLoading(false); - }); - }, - [assistant, notifications] - ); - - useEffect(() => { - controllerRef.current.abort(); - }, []); - - return { - ...response, - error, - loading, - abort: () => { - setLoading(false); - setError(new AbortError()); - controllerRef.current.abort(); - }, - generate, - }; -} diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_timeline.test.ts b/x-pack/plugins/observability_ai_assistant/public/hooks/use_timeline.test.ts index 08ee9df3a3085..fac5a5a9c9093 100644 --- a/x-pack/plugins/observability_ai_assistant/public/hooks/use_timeline.test.ts +++ b/x-pack/plugins/observability_ai_assistant/public/hooks/use_timeline.test.ts @@ -10,15 +10,14 @@ import { type Renderer, type RenderHookResult, } from '@testing-library/react-hooks'; -import { merge } from 'lodash'; -import { DeepPartial } from 'utility-types'; +import { BehaviorSubject, Subject } from 'rxjs'; +import { AbortError } from '@kbn/kibana-utils-plugin/common'; import { MessageRole } from '../../common'; -import { createNewConversation, useTimeline, UseTimelineResult } from './use_timeline'; +import { PendingMessage } from '../types'; +import { useTimeline, UseTimelineResult } from './use_timeline'; type HookProps = Parameters[0]; -const WAIT_OPTIONS = { timeout: 5000 }; - describe('useTimeline', () => { let hookResult: RenderHookResult>; @@ -27,7 +26,7 @@ describe('useTimeline', () => { hookResult = renderHook((props) => useTimeline(props), { initialProps: { connectors: {}, - chat: {}, + service: {}, } as HookProps, }); }); @@ -71,8 +70,10 @@ describe('useTimeline', () => { connectors: { selectedConnector: 'foo', }, - chat: {}, - } as HookProps, + service: { + chat: () => {}, + }, + } as unknown as HookProps, }); }); it('renders the correct timeline items', () => { @@ -103,92 +104,39 @@ describe('useTimeline', () => { }); describe('when submitting a new prompt', () => { - const createChatSimulator = (initialProps?: DeepPartial) => { - let resolve: (data: { content?: string; aborted?: boolean }) => void; + let subject: Subject; - const abort = () => { - resolve({ - content: props.chat.content, - aborted: true, - }); - rerender({ - chat: { - loading: false, + beforeEach(() => { + hookResult = renderHook((nextProps) => useTimeline(nextProps), { + initialProps: { + initialConversation: { + messages: [], }, - }); - }; - - let props = merge( - { - initialConversation: createNewConversation(), connectors: { - selectedConnector: 'myConnector', + selectedConnector: 'foo', }, - chat: { - loading: true, - content: undefined, - abort, - generate: () => { - const promise = new Promise((innerResolve) => { - resolve = (...args) => { - innerResolve(...args); - }; - }); - - rerender({ - chat: { + service: { + chat: jest.fn().mockImplementation(() => { + subject = new BehaviorSubject({ + message: { + role: MessageRole.Assistant, content: '', - loading: true, - error: undefined, - function_call: undefined, }, }); - return promise; - }, + return subject; + }), }, } as unknown as HookProps, - { - ...initialProps, - } - ); - - hookResult = renderHook((nextProps) => useTimeline(nextProps), { - initialProps: props, }); - - function rerender(nextProps: DeepPartial) { - props = merge({}, props, nextProps) as HookProps; - hookResult.rerender(props); - } - - return { - next: (nextValue: { content?: string }) => { - rerender({ - chat: { - content: nextValue.content, - }, - }); - }, - complete: () => { - resolve({ - content: props.chat.content, - }); - rerender({ - chat: { - loading: false, - }, - }); - }, - abort, - }; - }; + }); describe("and it's loading", () => { it('adds two items of which the last one is loading', async () => { - const simulator = createChatSimulator(); - act(() => { - hookResult.result.current.onSubmit({ content: 'Hello' }); + hookResult.result.current.onSubmit({ + '@timestamp': new Date().toISOString(), + message: { role: MessageRole.User, content: 'Hello' }, + }); }); expect(hookResult.result.current.items[0].role).toEqual(MessageRole.User); @@ -221,7 +169,7 @@ describe('useTimeline', () => { }); act(() => { - simulator.next({ content: 'Goodbye' }); + subject.next({ message: { role: MessageRole.Assistant, content: 'Goodbye' } }); }); expect(hookResult.result.current.items[2]).toMatchObject({ @@ -233,11 +181,9 @@ describe('useTimeline', () => { }); act(() => { - simulator.complete(); + subject.complete(); }); - await hookResult.waitForNextUpdate(WAIT_OPTIONS); - expect(hookResult.result.current.items[2]).toMatchObject({ role: MessageRole.Assistant, content: 'Goodbye', @@ -248,18 +194,23 @@ describe('useTimeline', () => { }); describe('and it being aborted', () => { - let simulator: ReturnType; - beforeEach(async () => { - simulator = createChatSimulator(); - act(() => { - hookResult.result.current.onSubmit({ content: 'Hello' }); - simulator.next({ content: 'My partial' }); - simulator.abort(); + hookResult.result.current.onSubmit({ + '@timestamp': new Date().toISOString(), + message: { role: MessageRole.User, content: 'Hello' }, + }); + subject.next({ message: { role: MessageRole.Assistant, content: 'My partial' } }); + subject.next({ + message: { + role: MessageRole.Assistant, + content: 'My partial', + }, + aborted: true, + error: new AbortError(), + }); + subject.complete(); }); - - await hookResult.waitForNextUpdate(WAIT_OPTIONS); }); it('adds the partial response', async () => { @@ -268,12 +219,13 @@ describe('useTimeline', () => { expect(hookResult.result.current.items[2]).toEqual({ canEdit: false, canRegenerate: true, - canGiveFeedback: true, + canGiveFeedback: false, content: 'My partial', id: expect.any(String), loading: false, title: '', role: MessageRole.Assistant, + error: expect.any(AbortError), }); }); @@ -281,6 +233,7 @@ describe('useTimeline', () => { beforeEach(() => { act(() => { hookResult.result.current.onRegenerate(hookResult.result.current.items[2]); + subject.next({ message: { role: MessageRole.Assistant, content: '' } }); }); }); @@ -303,8 +256,6 @@ describe('useTimeline', () => { hookResult.result.current.onStopGenerating(); }); - await hookResult.waitForNextUpdate(WAIT_OPTIONS); - act(() => { hookResult.result.current.onRegenerate(hookResult.result.current.items[2]); }); @@ -325,12 +276,10 @@ describe('useTimeline', () => { }); act(() => { - simulator.next({ content: 'Regenerated' }); - simulator.complete(); + subject.next({ message: { role: MessageRole.Assistant, content: 'Regenerated' } }); + subject.complete(); }); - await hookResult.waitForNextUpdate(WAIT_OPTIONS); - expect(hookResult.result.current.items.length).toBe(3); expect(hookResult.result.current.items[2]).toEqual({ diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_timeline.ts b/x-pack/plugins/observability_ai_assistant/public/hooks/use_timeline.ts index 916168b10b4a1..25e6b624fbd88 100644 --- a/x-pack/plugins/observability_ai_assistant/public/hooks/use_timeline.ts +++ b/x-pack/plugins/observability_ai_assistant/public/hooks/use_timeline.ts @@ -6,13 +6,14 @@ */ import type { AuthenticatedUser } from '@kbn/security-plugin/common'; -import { omit } from 'lodash'; -import { useMemo, useState } from 'react'; +import { useEffect, useMemo, useRef, useState } from 'react'; +import type { Subscription } from 'rxjs'; +import { AbortError } from '@kbn/kibana-utils-plugin/common'; import { MessageRole, type ConversationCreateRequest, type Message } from '../../common/types'; import type { ChatPromptEditorProps } from '../components/chat/chat_prompt_editor'; import type { ChatTimelineProps } from '../components/chat/chat_timeline'; +import type { ObservabilityAIAssistantService, PendingMessage } from '../types'; import { getTimelineItemsfromConversation } from '../utils/get_timeline_items_from_conversation'; -import type { UseChatResult } from './use_chat'; import type { UseGenAIConnectorsResult } from './use_genai_connectors'; export function createNewConversation(): ConversationCreateRequest { @@ -37,12 +38,12 @@ export function useTimeline({ initialConversation, connectors, currentUser, - chat, + service, }: { initialConversation?: ConversationCreateRequest; connectors: UseGenAIConnectorsResult; currentUser?: Pick; - chat: UseChatResult; + service: ObservabilityAIAssistantService; }): UseTimelineResult { const connectorId = connectors.selectedConnector; @@ -58,50 +59,115 @@ export function useTimeline({ }); }, [conversation, currentUser, hasConnector]); + const [subscription, setSubscription] = useState(); + + const [pendingMessage, setPendingMessage] = useState(); + + const controllerRef = useRef(new AbortController()); + + function chat(messages: Message[]): Promise { + const controller = new AbortController(); + return new Promise((resolve, reject) => { + if (!connectorId) { + reject(new Error('Can not add a message without a connector')); + return; + } + + setConversation((conv) => ({ + ...conv, + messages, + })); + + const response$ = service.chat({ messages, connectorId }); + + let pendingMessageLocal = pendingMessage; + + const nextSubscription = response$.subscribe({ + next: (nextPendingMessage) => { + pendingMessageLocal = nextPendingMessage; + setPendingMessage(() => nextPendingMessage); + }, + error: reject, + complete: () => { + resolve(pendingMessageLocal!); + }, + }); + + setSubscription(() => { + controllerRef.current = controller; + return nextSubscription; + }); + }) + .then(async (nextMessage) => { + if (nextMessage.error) { + return; + } + if (nextMessage.aborted) { + return; + } + + setPendingMessage(undefined); + + const nextMessages = messages.concat({ + '@timestamp': new Date().toISOString(), + message: { + ...nextMessage.message, + }, + }); + + setConversation((conv) => ({ ...conv, messages: nextMessages })); + + if (nextMessage?.message.function_call?.name) { + const name = nextMessage.message.function_call.name; + + const message = await service.executeFunction( + name, + nextMessage.message.function_call.arguments, + controller.signal + ); + + await chat( + nextMessages.concat({ + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + name, + content: JSON.stringify(message.content), + data: JSON.stringify(message.data), + }, + }) + ); + } + }) + .catch((err) => {}); + } + const items = useMemo(() => { - if (chat.loading) { + if (pendingMessage) { return conversationItems.concat({ id: '', canEdit: false, - canRegenerate: !chat.loading, - canGiveFeedback: !chat.loading, - role: MessageRole.Assistant, + canRegenerate: pendingMessage.aborted || !!pendingMessage.error, + canGiveFeedback: false, title: '', - content: chat.content ?? '', - loading: chat.loading, + role: pendingMessage.message.role, + content: pendingMessage.message.content, + loading: !pendingMessage.aborted && !pendingMessage.error, + function_call: pendingMessage.message.function_call, currentUser, + error: pendingMessage.error, }); } return conversationItems; - }, [conversationItems, chat.content, chat.loading, currentUser]); - - function getNextMessage( - role: MessageRole, - response: Awaited> - ) { - const nextMessage: Message = { - '@timestamp': new Date().toISOString(), - message: { - role, - content: response.content, - ...omit(response, 'function_call'), - ...(response.function_call && response.function_call.name - ? { - function_call: { - ...response.function_call, - args: response.function_call.args - ? JSON.parse(response.function_call.args) - : undefined, - trigger: MessageRole.Assistant, - }, - } - : {}), - }, - }; + }, [conversationItems, pendingMessage, currentUser]); - return nextMessage; - } + useEffect(() => { + return () => { + // controllerRef.current.abort(); + subscription?.unsubscribe(); + }; + }, [subscription]); return { items, @@ -111,40 +177,22 @@ export function useTimeline({ const indexOf = items.indexOf(item); const messages = conversation.messages.slice(0, indexOf - 1); - - setConversation((conv) => ({ ...conv, messages })); - - chat - .generate({ - messages, - connectorId: connectors.selectedConnector!, - }) - .then((response) => { - setConversation((conv) => ({ - ...conv, - messages: conv.messages.concat(getNextMessage(MessageRole.Assistant, response)), - })); - }); + chat(messages); }, onStopGenerating: () => { - chat.abort(); + subscription?.unsubscribe(); + setPendingMessage((prevPendingMessage) => ({ + message: { + role: MessageRole.Assistant, + ...prevPendingMessage?.message, + }, + aborted: true, + error: new AbortError(), + })); + setSubscription(undefined); }, - onSubmit: async ({ content }) => { - if (connectorId) { - const nextMessage = getNextMessage(MessageRole.User, { content }); - - setConversation((conv) => ({ ...conv, messages: conv.messages.concat(nextMessage) })); - - const response = await chat.generate({ - messages: conversation.messages.concat(nextMessage), - connectorId, - }); - - setConversation((conv) => ({ - ...conv, - messages: conv.messages.concat(getNextMessage(MessageRole.Assistant, response)), - })); - } + onSubmit: async (message) => { + await chat(conversation.messages.concat(message)); }, }; } diff --git a/x-pack/plugins/observability_ai_assistant/public/plugin.tsx b/x-pack/plugins/observability_ai_assistant/public/plugin.tsx index d469ea3383772..949178a4b3649 100644 --- a/x-pack/plugins/observability_ai_assistant/public/plugin.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/plugin.tsx @@ -4,19 +4,27 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ +import { EuiCodeBlock } from '@elastic/eui'; import { - type AppMountParameters, AppNavLinkStatus, - type CoreSetup, + type CoreStart, DEFAULT_APP_CATEGORIES, + type AppMountParameters, + type CoreSetup, type Plugin, type PluginInitializerContext, - CoreStart, } from '@kbn/core/public'; import { i18n } from '@kbn/i18n'; import type { Logger } from '@kbn/logging'; +import type { Serializable } from '@kbn/utility-types'; import React from 'react'; import ReactDOM from 'react-dom'; +import type { + ContextRegistry, + FunctionRegistry, + RegisterContextDefinition, + RegisterFunctionDefinition, +} from '../common/types'; import { createService } from './service/create_service'; import type { ConfigSchema, @@ -94,6 +102,74 @@ export class ObservabilityAIAssistantPlugin coreStart: CoreStart, pluginsStart: ObservabilityAIAssistantPluginStartDependencies ): ObservabilityAIAssistantPluginStart { - return (this.service = createService({ coreStart, securityStart: pluginsStart.security })); + const contextRegistry: ContextRegistry = new Map(); + const functionRegistry: FunctionRegistry = new Map(); + + const service = (this.service = createService({ + coreStart, + securityStart: pluginsStart.security, + contextRegistry, + functionRegistry, + })); + + const registerContext: RegisterContextDefinition = (context) => { + contextRegistry.set(context.name, context); + }; + + const registerFunction: RegisterFunctionDefinition = (def, respond, render) => { + functionRegistry.set(def.name, { options: def, respond, render }); + }; + + registerContext({ + name: 'core', + description: + 'Core functions, like calling Elasticsearch APIs, storing embeddables for instructions or creating base visualisations.', + }); + + registerFunction( + { + name: 'elasticsearch', + contexts: ['core'], + description: 'Call Elasticsearch APIs on behalf of the user', + parameters: { + type: 'object', + properties: { + method: { + type: 'string', + description: 'The HTTP method of the Elasticsearch endpoint', + enum: ['GET', 'PUT', 'POST', 'DELETE', 'PATCH'] as const, + }, + path: { + type: 'string', + description: 'The path of the Elasticsearch endpoint, including query parameters', + }, + }, + required: ['method' as const, 'path' as const], + }, + }, + ({ arguments: { method, path, body } }, signal) => { + return service + .callApi(`POST /internal/observability_ai_assistant/functions/elasticsearch`, { + signal, + params: { + body: { + method, + path, + body, + }, + }, + }) + .then((response) => ({ content: response as Serializable })); + }, + ({ response: { content } }) => { + return {JSON.stringify(content, null, 2)}; + } + ); + + return { + ...service, + registerContext, + registerFunction, + }; } } diff --git a/x-pack/plugins/observability_ai_assistant/public/routes/conversations/conversation_view.tsx b/x-pack/plugins/observability_ai_assistant/public/routes/conversations/conversation_view.tsx index 61052b9456dc1..47c82a171af6f 100644 --- a/x-pack/plugins/observability_ai_assistant/public/routes/conversations/conversation_view.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/routes/conversations/conversation_view.tsx @@ -5,27 +5,37 @@ * 2.0. */ import { EuiFlexGroup, EuiFlexItem } from '@elastic/eui'; +import { css } from '@emotion/css'; import React from 'react'; import { ChatBody } from '../../components/chat/chat_body'; -import { useChat } from '../../hooks/use_chat'; import { useCurrentUser } from '../../hooks/use_current_user'; import { useGenAIConnectors } from '../../hooks/use_genai_connectors'; +import { useObservabilityAIAssistant } from '../../hooks/use_observability_ai_assistant'; + +const containerClassName = css` + max-width: 100%; +`; + +const chatBodyContainerClassName = css` + max-width: 100%; +`; export function ConversationView() { const connectors = useGenAIConnectors(); - const chat = useChat(); const currentUser = useCurrentUser(); + const service = useObservabilityAIAssistant(); + return ( - + - + diff --git a/x-pack/plugins/observability_ai_assistant/public/service/create_service.test.ts b/x-pack/plugins/observability_ai_assistant/public/service/create_service.test.ts index 2da55a7a8f4fd..7129cf2d8f966 100644 --- a/x-pack/plugins/observability_ai_assistant/public/service/create_service.test.ts +++ b/x-pack/plugins/observability_ai_assistant/public/service/create_service.test.ts @@ -4,12 +4,13 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -import type { CoreStart } from '@kbn/core/public'; +import type { CoreStart, HttpFetchOptions } from '@kbn/core/public'; import { ReadableStream } from 'stream/web'; import type { AuthenticatedUser } from '@kbn/security-plugin/common'; import type { ObservabilityAIAssistantService } from '../types'; import { createService } from './create_service'; import { SecurityPluginStart } from '@kbn/security-plugin/public'; +import { lastValueFrom } from 'rxjs'; describe('createService', () => { describe('chat', () => { @@ -17,14 +18,14 @@ describe('createService', () => { const httpPostSpy = jest.fn(); - function respondWithChunks({ chunks, status = 200 }: { status?: number; chunks: string[][] }) { + function respondWithChunks({ chunks, status = 200 }: { status?: number; chunks: string[] }) { const response = { response: { status, body: new ReadableStream({ start(controller) { chunks.forEach((chunk) => { - controller.enqueue(new TextEncoder().encode(chunk.join('\n'))); + controller.enqueue(new TextEncoder().encode(chunk)); }); controller.close(); }, @@ -35,10 +36,8 @@ describe('createService', () => { httpPostSpy.mockResolvedValueOnce(response); } - async function chat(signal: AbortSignal = new AbortController().signal) { - const response = await service.chat({ messages: [], connectorId: '', signal }); - - return response; + function chat() { + return service.chat({ messages: [], connectorId: '' }); } beforeEach(() => { @@ -53,6 +52,8 @@ describe('createService', () => { getCurrentUser: () => Promise.resolve({ username: 'elastic' } as AuthenticatedUser), }, } as unknown as SecurityPluginStart, + contextRegistry: new Map(), + functionRegistry: new Map(), }); }); @@ -61,56 +62,144 @@ describe('createService', () => { }); it('correctly parses a stream of JSON lines', async () => { - const chunk1 = ['data: {}', 'data: {}']; - const chunk2 = ['data: {}', 'data: [DONE]']; + const chunk1 = + 'data: {"object":"chat.completion.chunk","choices":[{"delta":{"content":"My"}}]}\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" new"}}]}'; + const chunk2 = + '\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" message"}}]}\ndata: [DONE]'; respondWithChunks({ chunks: [chunk1, chunk2] }); - const response$ = await chat(); + const response$ = chat(); const results: any = []; - response$.subscribe({ + + const subscription = response$.subscribe({ next: (data) => results.push(data), complete: () => { - expect(results).toHaveLength(3); + expect(results).toHaveLength(4); + }, + }); + + const value = await lastValueFrom(response$); + subscription.unsubscribe(); + + expect(value).toEqual({ + message: { + role: 'assistant', + content: 'My new message', + function_call: { + arguments: '', + name: '', + trigger: 'assistant', + }, }, }); }); it('correctly buffers partial lines', async () => { - const chunk1 = ['data: {}', 'data: {']; - const chunk2 = ['}', 'data: [DONE]']; + const chunk1 = + 'data: {"object":"chat.completion.chunk","choices":[{"delta":{"content":"My"}}]}\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" new"'; + const chunk2 = + '}}]}\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" message"}}]}\ndata: [DONE]'; respondWithChunks({ chunks: [chunk1, chunk2] }); - const response$ = await chat(); + const response$ = chat(); const results: any = []; - response$.subscribe({ - next: (data) => results.push(data), - complete: () => { - expect(results).toHaveLength(2); + + await new Promise((resolve, reject) => { + response$.subscribe({ + next: (data) => { + results.push(data); + }, + error: reject, + complete: resolve, + }); + }); + + const value = await lastValueFrom(response$); + + expect(results).toHaveLength(4); + + expect(value).toEqual({ + message: { + role: 'assistant', + content: 'My new message', + function_call: { + arguments: '', + name: '', + trigger: 'assistant', + }, }, }); }); - it('propagates invalid requests as an error', () => { + it('catches invalid requests and flags it as an error', async () => { respondWithChunks({ status: 400, chunks: [] }); - expect(() => chat()).rejects.toThrowErrorMatchingInlineSnapshot(`"Unexpected error"`); + const response$ = chat(); + + const value = await lastValueFrom(response$); + + expect(value).toEqual({ + aborted: false, + error: expect.any(Error), + message: { + role: 'assistant', + }, + }); }); it('propagates JSON parsing errors', async () => { - const chunk1 = ['data: {}', 'data: invalid json']; + respondWithChunks({ chunks: ['data: {}', 'data: invalid json'] }); + + const response$ = chat(); - respondWithChunks({ chunks: [chunk1] }); + const value = await lastValueFrom(response$); + + expect(value).toEqual({ + aborted: false, + error: expect.any(Error), + message: { + role: 'assistant', + }, + }); + }); + + it('cancels a running http request when aborted', async () => { + httpPostSpy.mockImplementationOnce((endpoint: string, options: HttpFetchOptions) => { + options.signal?.addEventListener('abort', () => { + expect(options.signal?.aborted).toBeTruthy(); + }); + return Promise.resolve({ + response: { + status: 200, + body: new ReadableStream({ + start(controller) {}, + }), + }, + }); + }); + + const response$ = chat(); + + await new Promise((resolve, reject) => { + const subscription = response$.subscribe({}); + + setTimeout(() => { + subscription.unsubscribe(); + resolve(); + }, 100); + }); - const response$ = await chat(); + const value = await lastValueFrom(response$); - response$.subscribe({ - error: (err) => { - expect(err).toBeInstanceOf(SyntaxError); + expect(value).toEqual({ + message: { + role: 'assistant', }, + aborted: true, }); }); }); diff --git a/x-pack/plugins/observability_ai_assistant/public/service/create_service.ts b/x-pack/plugins/observability_ai_assistant/public/service/create_service.ts index a00352a39d0ce..4a79bee894675 100644 --- a/x-pack/plugins/observability_ai_assistant/public/service/create_service.ts +++ b/x-pack/plugins/observability_ai_assistant/public/service/create_service.ts @@ -6,67 +6,197 @@ */ import type { CoreStart, HttpResponse } from '@kbn/core/public'; +import { AbortError } from '@kbn/kibana-utils-plugin/common'; import { SecurityPluginStart } from '@kbn/security-plugin/public'; -import { filter, map } from 'rxjs'; +import { IncomingMessage } from 'http'; +import { cloneDeep } from 'lodash'; +import { + BehaviorSubject, + catchError, + concatMap, + delay, + filter as rxJsFilter, + finalize, + map, + of, + scan, + shareReplay, +} from 'rxjs'; import type { Message } from '../../common'; +import { ContextRegistry, FunctionRegistry, MessageRole } from '../../common/types'; import { createCallObservabilityAIAssistantAPI } from '../api'; -import type { CreateChatCompletionResponseChunk, ObservabilityAIAssistantService } from '../types'; +import type { + CreateChatCompletionResponseChunk, + ObservabilityAIAssistantService, + PendingMessage, +} from '../types'; import { readableStreamReaderIntoObservable } from '../utils/readable_stream_reader_into_observable'; export function createService({ coreStart, securityStart, + functionRegistry, + contextRegistry, }: { coreStart: CoreStart; securityStart: SecurityPluginStart; + functionRegistry: FunctionRegistry; + contextRegistry: ContextRegistry; }): ObservabilityAIAssistantService { const client = createCallObservabilityAIAssistantAPI(coreStart); + const getContexts: ObservabilityAIAssistantService['getContexts'] = () => { + return Array.from(contextRegistry.values()); + }; + const getFunctions: ObservabilityAIAssistantService['getFunctions'] = ({ + contexts, + filter, + } = {}) => { + const allFunctions = Array.from(functionRegistry.values()); + + return contexts || filter + ? allFunctions.filter((fn) => { + const matchesContext = + !contexts || fn.options.contexts.some((context) => contexts.includes(context)); + const matchesFilter = + !filter || fn.options.name.includes(filter) || fn.options.description.includes(filter); + + return matchesContext && matchesFilter; + }) + : allFunctions; + }; + return { isEnabled: () => { return true; }, - async chat({ - connectorId, - messages, - signal, - }: { - connectorId: string; - messages: Message[]; - signal: AbortSignal; - }) { - const response = (await client('POST /internal/observability_ai_assistant/chat', { + chat({ connectorId, messages }: { connectorId: string; messages: Message[] }) { + const subject = new BehaviorSubject({ + message: { + role: MessageRole.Assistant, + }, + }); + + const contexts = ['core']; + + const functions = getFunctions({ contexts }); + + const controller = new AbortController(); + + client('POST /internal/observability_ai_assistant/chat', { params: { body: { messages, connectorId, + functions: functions.map((fn) => fn.options), }, }, - signal, + signal: controller.signal, asResponse: true, rawResponse: true, - })) as unknown as HttpResponse; + }) + .then((_response) => { + const response = _response as unknown as HttpResponse; - const status = response.response?.status; + const status = response.response?.status; - if (!status || status >= 400) { - throw new Error(response.response?.statusText || 'Unexpected error'); - } + if (!status || status >= 400) { + throw new Error(response.response?.statusText || 'Unexpected error'); + } - const reader = response.response.body?.getReader(); + const reader = response.response.body?.getReader(); - if (!reader) { - throw new Error('Could not get reader from response'); - } + if (!reader) { + throw new Error('Could not get reader from response'); + } - return readableStreamReaderIntoObservable(reader).pipe( - map((line) => line.substring(6)), - filter((line) => !!line && line !== '[DONE]'), - map((line) => JSON.parse(line) as CreateChatCompletionResponseChunk), - filter((line) => line.object === 'chat.completion.chunk') + const subscription = readableStreamReaderIntoObservable(reader) + .pipe( + map((line) => line.substring(6)), + rxJsFilter((line) => !!line && line !== '[DONE]'), + map((line) => JSON.parse(line) as CreateChatCompletionResponseChunk), + rxJsFilter((line) => line.object === 'chat.completion.chunk'), + scan( + (acc, { choices }) => { + acc.message.content += choices[0].delta.content ?? ''; + acc.message.function_call.name += choices[0].delta.function_call?.name ?? ''; + acc.message.function_call.arguments += + choices[0].delta.function_call?.arguments ?? ''; + return cloneDeep(acc); + }, + { + message: { + content: '', + function_call: { + name: '', + arguments: '', + trigger: MessageRole.Assistant as const, + }, + role: MessageRole.Assistant, + }, + } + ), + catchError((error) => + of({ + ...subject.value, + error, + aborted: error instanceof AbortError || controller.signal.aborted, + }) + ) + ) + .subscribe(subject); + + controller.signal.addEventListener('abort', () => { + subscription.unsubscribe(); + subject.next({ + ...subject.value, + aborted: true, + }); + subject.complete(); + }); + }) + .catch((err) => { + subject.next({ + ...subject.value, + aborted: false, + error: err, + }); + subject.complete(); + }); + + return subject.pipe( + concatMap((value) => of(value).pipe(delay(50))), + shareReplay(1), + finalize(() => { + controller.abort(); + }) ); }, callApi: client, getCurrentUser: () => securityStart.authc.getCurrentUser(), + getContexts, + getFunctions, + executeFunction: async (name, args, signal) => { + const fn = functionRegistry.get(name); + + if (!fn) { + throw new Error(`Function ${name} not found`); + } + + const parsedArguments = args ? JSON.parse(args) : {}; + + // validate + + return await fn.respond({ arguments: parsedArguments }, signal); + }, + renderFunction: (name, response) => { + const fn = functionRegistry.get(name); + + if (!fn) { + throw new Error(`Function ${name} not found`); + } + + return fn.render?.({ response }); + }, }; } diff --git a/x-pack/plugins/observability_ai_assistant/public/types.ts b/x-pack/plugins/observability_ai_assistant/public/types.ts index bd51b6e63dbf6..61f74fc7735d7 100644 --- a/x-pack/plugins/observability_ai_assistant/public/types.ts +++ b/x-pack/plugins/observability_ai_assistant/public/types.ts @@ -4,6 +4,10 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ +import type { + ObservabilitySharedPluginSetup, + ObservabilitySharedPluginStart, +} from '@kbn/observability-shared-plugin/public'; import type { AuthenticatedUser, SecurityPluginSetup, @@ -13,16 +17,19 @@ import type { TriggersAndActionsUIPublicPluginSetup, TriggersAndActionsUIPublicPluginStart, } from '@kbn/triggers-actions-ui-plugin/public'; -import type { - ObservabilitySharedPluginSetup, - ObservabilitySharedPluginStart, -} from '@kbn/observability-shared-plugin/public'; import type { CreateChatCompletionResponse, CreateChatCompletionResponseChoicesInner, } from 'openai'; import type { Observable } from 'rxjs'; -import type { Message } from '../common/types'; +import { Serializable } from '@kbn/utility-types'; +import type { + ContextDefinition, + FunctionDefinition, + Message, + RegisterContextDefinition, + RegisterFunctionDefinition, +} from '../common/types'; import type { ObservabilityAIAssistantAPIClient } from './api'; /* eslint-disable @typescript-eslint/no-empty-interface*/ @@ -30,33 +37,49 @@ import type { ObservabilityAIAssistantAPIClient } from './api'; export type CreateChatCompletionResponseChunk = Omit & { choices: Array< Omit & { - delta: { content?: string; function_call?: { name?: string; args?: string } }; + delta: { content?: string; function_call?: { name?: string; arguments?: string } }; } >; }; +export interface PendingMessage { + message: Message['message']; + aborted?: boolean; + error?: any; +} + export interface ObservabilityAIAssistantService { isEnabled: () => boolean; - chat: (options: { - messages: Message[]; - connectorId: string; - signal: AbortSignal; - }) => Promise>; + chat: (options: { messages: Message[]; connectorId: string }) => Observable; callApi: ObservabilityAIAssistantAPIClient; getCurrentUser: () => Promise; + getContexts: () => ContextDefinition[]; + getFunctions: (options?: { contexts?: string[]; filter?: string }) => FunctionDefinition[]; + executeFunction: ( + name: string, + args: string | undefined, + signal: AbortSignal + ) => Promise<{ content?: Serializable; data?: Serializable }>; + renderFunction: ( + name: string, + response: { data?: Serializable; content?: Serializable } + ) => React.ReactNode; } -export interface ObservabilityAIAssistantPluginStart extends ObservabilityAIAssistantService {} +export interface ObservabilityAIAssistantPluginStart extends ObservabilityAIAssistantService { + registerContext: RegisterContextDefinition; + registerFunction: RegisterFunctionDefinition; +} export interface ObservabilityAIAssistantPluginSetup {} export interface ObservabilityAIAssistantPluginSetupDependencies { - triggersActions: TriggersAndActionsUIPublicPluginSetup; + triggersActionsUi: TriggersAndActionsUIPublicPluginSetup; security: SecurityPluginSetup; observabilityShared: ObservabilitySharedPluginSetup; } export interface ObservabilityAIAssistantPluginStartDependencies { security: SecurityPluginStart; - triggersActions: TriggersAndActionsUIPublicPluginStart; + triggersActionsUi: TriggersAndActionsUIPublicPluginStart; observabilityShared: ObservabilitySharedPluginStart; } diff --git a/x-pack/plugins/observability_ai_assistant/public/utils/builders.ts b/x-pack/plugins/observability_ai_assistant/public/utils/builders.ts index 259a6c9f652ad..819734b8fdfab 100644 --- a/x-pack/plugins/observability_ai_assistant/public/utils/builders.ts +++ b/x-pack/plugins/observability_ai_assistant/public/utils/builders.ts @@ -65,7 +65,7 @@ export function buildFunctionInnerMessage(params: Omit ({ - id: v4(), - role: message.message.role, - title: message.message.role === MessageRole.System ? 'added a system prompt' : '', - content: - message.message.role === MessageRole.System ? undefined : message.message.content || '', - canEdit: - hasConnector && - (message.message.role === MessageRole.User || - message.message.role === MessageRole.Function), - canGiveFeedback: message.message.role === MessageRole.Assistant, - canRegenerate: hasConnector && message.message.role === MessageRole.Assistant, - loading: false, - currentUser, - })), + ...conversation.messages.map((message) => { + const hasFunction = !!message.message.function_call?.name; + const isSystemPrompt = message.message.role === MessageRole.System; + + let title: string; + if (hasFunction) { + title = i18n.translate('xpkac.observabilityAiAssistant.suggestedFunctionEvent', { + defaultMessage: 'suggested a function', + }); + } else if (isSystemPrompt) { + title = i18n.translate('xpack.observabilityAiAssistant.addedSystemPromptEvent', { + defaultMessage: 'added a prompt', + }); + } else { + title = ''; + } + + const props = { + id: v4(), + role: message.message.role, + canEdit: hasConnector && (message.message.role === MessageRole.User || hasFunction), + canRegenerate: hasConnector && message.message.role === MessageRole.Assistant, + canGiveFeedback: message.message.role === MessageRole.Assistant, + loading: false, + title, + content: message.message.content, + currentUser, + }; + + return props; + }), ]; } diff --git a/x-pack/plugins/observability_ai_assistant/server/routes/chat/route.ts b/x-pack/plugins/observability_ai_assistant/server/routes/chat/route.ts index 5e0744a7f7238..30da7d10fed91 100644 --- a/x-pack/plugins/observability_ai_assistant/server/routes/chat/route.ts +++ b/x-pack/plugins/observability_ai_assistant/server/routes/chat/route.ts @@ -19,6 +19,14 @@ const chatRoute = createObservabilityAIAssistantServerRoute({ body: t.type({ messages: t.array(messageRt), connectorId: t.string, + functions: t.array( + t.type({ + name: t.string, + description: t.string, + parameters: t.any, + contexts: t.array(t.string), + }) + ), }), }), handler: async (resources): Promise => { @@ -30,9 +38,14 @@ const chatRoute = createObservabilityAIAssistantServerRoute({ throw notImplemented(); } + const { + body: { messages, connectorId, functions }, + } = params; + return client.chat({ - messages: params.body.messages, - connectorId: params.body.connectorId, + messages, + connectorId, + functions, }); }, }); diff --git a/x-pack/plugins/observability_ai_assistant/server/routes/functions/route.ts b/x-pack/plugins/observability_ai_assistant/server/routes/functions/route.ts new file mode 100644 index 0000000000000..b2469fe7622e9 --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/server/routes/functions/route.ts @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +import * as t from 'io-ts'; +import { createObservabilityAIAssistantServerRoute } from '../create_observability_ai_assistant_server_route'; + +const functionElasticsearchRoute = createObservabilityAIAssistantServerRoute({ + endpoint: 'POST /internal/observability_ai_assistant/functions/elasticsearch', + options: { + tags: ['access:ai_assistant'], + }, + params: t.type({ + body: t.intersection([ + t.type({ + method: t.union([ + t.literal('GET'), + t.literal('POST'), + t.literal('PATCH'), + t.literal('PUT'), + t.literal('DELETE'), + ]), + path: t.string, + }), + t.partial({ + body: t.any, + }), + ]), + }), + handler: async (resources): Promise => { + const { method, path, body } = resources.params.body; + + const response = await ( + await resources.context.core + ).elasticsearch.client.asCurrentUser.transport.request({ + method, + path, + body, + }); + + return response; + }, +}); + +export const functionRoutes = { + ...functionElasticsearchRoute, +}; diff --git a/x-pack/plugins/observability_ai_assistant/server/routes/get_global_observability_ai_assistant_route_repository.ts b/x-pack/plugins/observability_ai_assistant/server/routes/get_global_observability_ai_assistant_route_repository.ts index fcca8df0f03e7..a01033137600b 100644 --- a/x-pack/plugins/observability_ai_assistant/server/routes/get_global_observability_ai_assistant_route_repository.ts +++ b/x-pack/plugins/observability_ai_assistant/server/routes/get_global_observability_ai_assistant_route_repository.ts @@ -8,12 +8,14 @@ import { chatRoutes } from './chat/route'; import { connectorRoutes } from './connectors/route'; import { conversationRoutes } from './conversations/route'; +import { functionRoutes } from './functions/route'; export function getGlobalObservabilityAIAssistantServerRouteRepository() { return { ...chatRoutes, ...conversationRoutes, ...connectorRoutes, + ...functionRoutes, }; } diff --git a/x-pack/plugins/observability_ai_assistant/server/routes/runtime_types.ts b/x-pack/plugins/observability_ai_assistant/server/routes/runtime_types.ts index 7fb2ed98fa1a7..12b66ff988ac5 100644 --- a/x-pack/plugins/observability_ai_assistant/server/routes/runtime_types.ts +++ b/x-pack/plugins/observability_ai_assistant/server/routes/runtime_types.ts @@ -42,7 +42,7 @@ export const messageRt: t.Type = t.type({ ]), }), t.partial({ - args: serializeableRt, + arguments: serializeableRt, data: serializeableRt, }), ]), diff --git a/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts b/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts index 2afa7b9be3448..50eef7dbf2377 100644 --- a/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts +++ b/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts @@ -4,24 +4,28 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -import { v4 } from 'uuid'; - -import type { ChatCompletionRequestMessage, CreateChatCompletionRequest } from 'openai'; -import type { IncomingMessage } from 'http'; +import type { SearchHit } from '@elastic/elasticsearch/lib/api/types'; +import { internal, notFound } from '@hapi/boom'; +import type { ActionsClient } from '@kbn/actions-plugin/server/actions_client'; import type { ElasticsearchClient } from '@kbn/core/server'; import type { Logger } from '@kbn/logging'; -import type { ActionsClient } from '@kbn/actions-plugin/server/actions_client'; +import { OpenAiProviderType } from '@kbn/stack-connectors-plugin/common/gen_ai/constants'; import type { PublicMethodsOf } from '@kbn/utility-types'; -import { internal, notFound } from '@hapi/boom'; +import type { IncomingMessage } from 'http'; import { compact, isEmpty, merge, omit } from 'lodash'; -import type { SearchHit } from '@elastic/elasticsearch/lib/api/types'; -import { OpenAiProviderType } from '@kbn/stack-connectors-plugin/common/gen_ai/constants'; +import type { + ChatCompletionFunctions, + ChatCompletionRequestMessage, + CreateChatCompletionRequest, +} from 'openai'; +import { v4 } from 'uuid'; import { + type FunctionDefinition, + MessageRole, type Conversation, type ConversationCreateRequest, type ConversationUpdateRequest, type Message, - MessageRole, } from '../../../common/types'; import type { IObservabilityAIAssistantClient, @@ -111,9 +115,11 @@ export class ObservabilityAIAssistantClient implements IObservabilityAIAssistant chat = async ({ messages, connectorId, + functions, }: { messages: Message[]; connectorId: string; + functions: Array; }): Promise => { const messagesForOpenAI: ChatCompletionRequestMessage[] = compact( messages @@ -133,6 +139,10 @@ export class ObservabilityAIAssistantClient implements IObservabilityAIAssistant }) ); + const functionsForOpenAI: ChatCompletionFunctions[] = functions.map((fn) => + omit(fn, 'contexts') + ); + const connector = await this.dependencies.actionsClient.get({ id: connectorId, }); @@ -141,6 +151,8 @@ export class ObservabilityAIAssistantClient implements IObservabilityAIAssistant ...(connector.config?.apiProvider === OpenAiProviderType.OpenAi ? { model: 'gpt-4' } : {}), messages: messagesForOpenAI, stream: true, + functions: functionsForOpenAI, + temperature: 0.1, }; const executeResult = await this.dependencies.actionsClient.execute({ diff --git a/x-pack/plugins/observability_ai_assistant/server/service/conversation_component_template.ts b/x-pack/plugins/observability_ai_assistant/server/service/conversation_component_template.ts index ce3a8d991e224..2ce8180b0fdc9 100644 --- a/x-pack/plugins/observability_ai_assistant/server/service/conversation_component_template.ts +++ b/x-pack/plugins/observability_ai_assistant/server/service/conversation_component_template.ts @@ -76,7 +76,7 @@ export const conversationComponentTemplate: ClusterComponentTemplate['component_ type: 'object', properties: { name: keyword, - args: { + arguments: { type: 'object', enabled: false, }, diff --git a/x-pack/plugins/observability_ai_assistant/server/service/types.ts b/x-pack/plugins/observability_ai_assistant/server/service/types.ts index d8f0c6b2b5f09..56824d5506c18 100644 --- a/x-pack/plugins/observability_ai_assistant/server/service/types.ts +++ b/x-pack/plugins/observability_ai_assistant/server/service/types.ts @@ -7,15 +7,20 @@ import { IncomingMessage } from 'http'; import { KibanaRequest } from '@kbn/core/server'; -import { +import type { Conversation, ConversationCreateRequest, ConversationUpdateRequest, + FunctionDefinition, Message, } from '../../common/types'; export interface IObservabilityAIAssistantClient { - chat: (options: { messages: Message[]; connectorId: string }) => Promise; + chat: (options: { + messages: Message[]; + connectorId: string; + functions: Array; + }) => Promise; get: (conversationId: string) => void; find: (options?: { query?: string }) => Promise<{ conversations: Conversation[] }>; create: (conversation: ConversationCreateRequest) => Promise; diff --git a/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/gen_ai.ts b/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/gen_ai.ts index 29214d18709bd..fde53a53803ac 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/gen_ai.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/gen_ai/gen_ai.ts @@ -82,6 +82,7 @@ export class GenAiConnector extends SubActionConnector