Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ai): Make response parsing extensible #14196

Merged
merged 4 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions packages/ai-chat/src/browser/ai-chat-frontend-module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import { UniversalChatAgent } from '../common/universal-chat-agent';
import { aiChatPreferences } from './ai-chat-preferences';
import { ChatAgentsVariableContribution } from '../common/chat-agents-variable-contribution';
import { FrontendChatServiceImpl } from './frontend-chat-service';
import { DefaultResponseContentMatcherProvider, DefaultResponseContentFactory, ResponseContentMatcherProvider } from '../common/response-content-matcher';

export default new ContainerModule(bind => {
bindContributionProvider(bind, Agent);
Expand All @@ -42,6 +43,11 @@ export default new ContainerModule(bind => {
bind(ChatAgentService).toService(ChatAgentServiceImpl);
bind(DefaultChatAgentId).toConstantValue({ id: OrchestratorChatAgentId });

bindContributionProvider(bind, ResponseContentMatcherProvider);
bind(DefaultResponseContentMatcherProvider).toSelf().inSingletonScope();
bind(ResponseContentMatcherProvider).toService(DefaultResponseContentMatcherProvider);
bind(DefaultResponseContentFactory).toSelf().inSingletonScope();

bind(AIVariableContribution).to(ChatAgentsVariableContribution).inSingletonScope();

bind(ChatRequestParserImpl).toSelf().inSingletonScope();
Expand Down
119 changes: 59 additions & 60 deletions packages/ai-chat/src/common/chat-agents.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
LanguageModel,
LanguageModelRequirement,
LanguageModelResponse,
LanguageModelStreamResponse,
PromptService,
ResolvedPromptTemplate,
ToolRequest,
Expand All @@ -37,19 +38,20 @@ import {
LanguageModelStreamResponsePart,
MessageActor,
} from '@theia/ai-core/lib/common';
import { CancellationToken, CancellationTokenSource, ILogger, isArray } from '@theia/core';
import { inject, injectable } from '@theia/core/shared/inversify';
import { CancellationToken, CancellationTokenSource, ContributionProvider, ILogger, isArray } from '@theia/core';
import { inject, injectable, named, postConstruct } from '@theia/core/shared/inversify';
import { ChatAgentService } from './chat-agent-service';
import {
ChatModel,
ChatRequestModel,
ChatRequestModelImpl,
ChatResponseContent,
CodeChatResponseContentImpl,
ErrorChatResponseContentImpl,
MarkdownChatResponseContentImpl,
ToolCallChatResponseContentImpl
} from './chat-model';
import { findFirstMatch, parseContents } from './parse-contents';
import { DefaultResponseContentFactory, ResponseContentMatcher, ResponseContentMatcherProvider } from './response-content-matcher';

/**
* A conversation consists of a sequence of ChatMessages.
Expand Down Expand Up @@ -121,6 +123,14 @@ export abstract class AbstractChatAgent {
@inject(ILogger) protected logger: ILogger;
@inject(CommunicationRecordingService) protected recordingService: CommunicationRecordingService;
@inject(PromptService) protected promptService: PromptService;

@inject(ContributionProvider) @named(ResponseContentMatcherProvider)
protected contentMatcherProviders: ContributionProvider<ResponseContentMatcherProvider>;
protected contentMatchers: ResponseContentMatcher[] = [];

@inject(DefaultResponseContentFactory)
protected defaultContentFactory: DefaultResponseContentFactory;

constructor(
public id: string,
public languageModelRequirements: LanguageModelRequirement[],
Expand All @@ -130,6 +140,11 @@ export abstract class AbstractChatAgent {
public tags: String[] = ['Chat']) {
}

@postConstruct()
init(): void {
this.contentMatchers = this.contentMatcherProviders.getContributions().flatMap(provider => provider.matchers);
}

async invoke(request: ChatRequestModelImpl): Promise<void> {
try {
const languageModel = await this.getLanguageModel(this.defaultLanguageModelPurpose);
Expand Down Expand Up @@ -189,6 +204,14 @@ export abstract class AbstractChatAgent {
}
}

/**
 * Splits the given text into structured response contents using the
 * matchers contributed via `ResponseContentMatcherProvider`; segments
 * not claimed by any matcher are wrapped by the default content factory.
 *
 * @param text the (partial or complete) response text to parse
 * @returns the parsed response contents, in source order
 */
protected parseContents(text: string): ChatResponseContent[] {
    return parseContents(
        text,
        this.contentMatchers,
        // bind so the factory keeps its `this` when invoked by the parser
        this.defaultContentFactory?.create.bind(this.defaultContentFactory)
    );
}

protected handleError(request: ChatRequestModelImpl, error: Error): void {
request.response.response.addContent(new ErrorChatResponseContentImpl(error));
request.response.error(error);
Expand Down Expand Up @@ -281,9 +304,8 @@ export abstract class AbstractStreamParsingChatAgent extends AbstractChatAgent {

protected override async addContentsToResponse(languageModelResponse: LanguageModelResponse, request: ChatRequestModelImpl): Promise<void> {
if (isLanguageModelTextResponse(languageModelResponse)) {
request.response.response.addContent(
new MarkdownChatResponseContentImpl(languageModelResponse.text)
);
const contents = this.parseContents(languageModelResponse.text);
request.response.response.addContents(contents);
request.response.complete();
this.recordingService.recordResponse({
agentId: this.id,
Expand All @@ -295,57 +317,7 @@ export abstract class AbstractStreamParsingChatAgent extends AbstractChatAgent {
return;
}
if (isLanguageModelStreamResponse(languageModelResponse)) {
for await (const token of languageModelResponse.stream) {
const newContents = this.parse(token, request.response.response.content);
if (isArray(newContents)) {
newContents.forEach(newContent => request.response.response.addContent(newContent));
} else {
request.response.response.addContent(newContents);
}

const lastContent = request.response.response.content.pop();
if (lastContent === undefined) {
return;
}
const text = lastContent.asString?.();
if (text === undefined) {
return;
}
let curSearchIndex = 0;
const result: ChatResponseContent[] = [];
while (curSearchIndex < text.length) {
// find start of code block: ```[language]\n<code>[\n]```
const codeStartIndex = text.indexOf('```', curSearchIndex);
if (codeStartIndex === -1) {
break;
}

// find language specifier if present
const newLineIndex = text.indexOf('\n', codeStartIndex + 3);
const language = codeStartIndex + 3 < newLineIndex ? text.substring(codeStartIndex + 3, newLineIndex) : undefined;

// find end of code block
const codeEndIndex = text.indexOf('```', codeStartIndex + 3);
if (codeEndIndex === -1) {
break;
}

// add text before code block as markdown content
result.push(new MarkdownChatResponseContentImpl(text.substring(curSearchIndex, codeStartIndex)));
// add code block as code content
const codeText = text.substring(newLineIndex + 1, codeEndIndex).trimEnd();
result.push(new CodeChatResponseContentImpl(codeText, language));
curSearchIndex = codeEndIndex + 3;
}

if (result.length > 0) {
result.forEach(r => {
request.response.response.addContent(r);
});
} else {
request.response.response.addContent(lastContent);
}
}
await this.addStreamResponse(languageModelResponse, request);
request.response.complete();
this.recordingService.recordResponse({
agentId: this.id,
Expand All @@ -366,19 +338,46 @@ export abstract class AbstractStreamParsingChatAgent extends AbstractChatAgent {
);
}

private parse(token: LanguageModelStreamResponsePart, previousContent: ChatResponseContent[]): ChatResponseContent | ChatResponseContent[] {
/**
 * Consumes a streamed language model response token by token and feeds the
 * resulting contents into the chat response model.
 *
 * After each token, the merged tail of the response is re-parsed against the
 * registered content matchers: a matcher may only apply once enough of the
 * stream has arrived (e.g. a closed code fence), so the last content element
 * is popped, re-parsed, and replaced by the parse result when a match exists.
 *
 * @param languageModelResponse the streaming response to consume
 * @param request the chat request whose response model is updated
 */
protected async addStreamResponse(languageModelResponse: LanguageModelStreamResponse, request: ChatRequestModelImpl): Promise<void> {
    for await (const token of languageModelResponse.stream) {
        const newContents = this.parse(token, request.response.response.content);
        if (isArray(newContents)) {
            request.response.response.addContents(newContents);
        } else {
            request.response.response.addContent(newContents);
        }

        const lastContent = request.response.response.content.pop();
        if (lastContent === undefined) {
            return;
        }
        const text = lastContent.asString?.();
        if (text === undefined) {
            // The tail content cannot be rendered as text for re-parsing;
            // restore it instead of silently dropping the popped element.
            request.response.response.addContent(lastContent);
            return;
        }

        // Only run the (comparatively expensive) full re-parse when at least
        // one matcher actually applies to the accumulated text.
        const result: ChatResponseContent[] = findFirstMatch(this.contentMatchers, text) ? this.parseContents(text) : [];
        if (result.length > 0) {
            request.response.response.addContents(result);
        } else {
            request.response.response.addContent(lastContent);
        }
    }
}

/**
 * Converts a single stream token into chat response content.
 * Text tokens go through the default content factory; tool-call tokens are
 * mapped to tool-call contents; anything else yields empty default content.
 *
 * @param token the streamed response part to convert
 * @param previousContent the contents produced so far (available to overrides)
 * @returns one content element, or one per tool call in the token
 */
protected parse(token: LanguageModelStreamResponsePart, previousContent: ChatResponseContent[]): ChatResponseContent | ChatResponseContent[] {
    const text = token.content;
    // eslint-disable-next-line no-null/no-null
    if (text !== undefined && text !== null) {
        return this.defaultContentFactory.create(text);
    }
    if (token.tool_calls !== undefined) {
        return token.tool_calls.map(call => new ToolCallChatResponseContentImpl(
            call.id,
            call.function?.name,
            call.function?.arguments,
            call.finished,
            call.result
        ));
    }
    return this.defaultContentFactory.create('');
}

}
18 changes: 13 additions & 5 deletions packages/ai-chat/src/common/chat-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -601,10 +601,20 @@ class ChatResponseImpl implements ChatResponse {
return this._content;
}

addContents(contents: ChatResponseContent[]): void {
contents.forEach(c => this.doAddContent(c));
this._onDidChangeEmitter.fire();
}

/**
 * Appends a single content element to this response (merging it into the
 * previous element where possible, see `doAddContent`) and notifies listeners.
 */
addContent(nextContent: ChatResponseContent): void {
    // TODO: Support more complex merges affecting different content than the last, e.g. via some kind of ProcessorRegistry
    // TODO: Support more of the built-in VS Code behavior, see
    // https://github.com/microsoft/vscode/blob/a2cab7255c0df424027be05d58e1b7b941f4ea60/src/vs/workbench/contrib/chat/common/chatModel.ts#L188-L244
    this.doAddContent(nextContent);
    this._onDidChangeEmitter.fire();
}

protected doAddContent(nextContent: ChatResponseContent): void {
if (ToolCallChatResponseContent.is(nextContent) && nextContent.id !== undefined) {
const fittingTool = this._content.find(c => ToolCallChatResponseContent.is(c) && c.id === nextContent.id);
if (fittingTool !== undefined) {
Expand All @@ -613,10 +623,9 @@ class ChatResponseImpl implements ChatResponse {
this._content.push(nextContent);
}
} else {
const lastElement =
this._content.length > 0
? this._content[this._content.length - 1]
: undefined;
const lastElement = this._content.length > 0
? this._content[this._content.length - 1]
: undefined;
if (lastElement?.kind === nextContent.kind && ChatResponseContent.hasMerge(lastElement)) {
const mergeSuccess = lastElement.merge(nextContent);
if (!mergeSuccess) {
Expand All @@ -627,7 +636,6 @@ class ChatResponseImpl implements ChatResponse {
}
}
this._updateResponseRepresentation();
this._onDidChangeEmitter.fire();
}

protected _updateResponseRepresentation(): void {
Expand Down
Loading
Loading