Skip to content

Commit

Permalink
fix (ai/ui): tool call streaming (#2345)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrammel authored Jul 19, 2024
1 parent a44a8f3 commit 5b7b3bb
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 12 deletions.
6 changes: 6 additions & 0 deletions .changeset/eleven-zoos-dream.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
'@ai-sdk/ui-utils': patch
'@ai-sdk/react': patch
---

fix (ai/ui): tool call streaming
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import { openai } from '@ai-sdk/openai';
import { convertToCoreMessages, streamText } from 'ai';
import { z } from 'zod';

// Allow streaming responses up to 30 seconds
export const maxDuration = 30;

export async function POST(req: Request) {
const { messages } = await req.json();

const result = await streamText({
model: openai('gpt-4-turbo'),
messages: convertToCoreMessages(messages),
experimental_toolCallStreaming: true,
system:
'You are a helpful assistant that answers questions about the weather in a given city.' +
'You use the showWeatherInformation tool to show the weather information to the user instead of talking about it.',
tools: {
// server-side tool with execute function:
getWeatherInformation: {
description: 'show the weather in a given city to the user',
parameters: z.object({ city: z.string() }),
execute: async ({}: { city: string }) => {
const weatherOptions = ['sunny', 'cloudy', 'rainy', 'snowy', 'windy'];
return {
weather:
weatherOptions[Math.floor(Math.random() * weatherOptions.length)],
temperature: Math.floor(Math.random() * 50 - 10),
};
},
},
// client-side tool that displays whether information to the user:
showWeatherInformation: {
description:
'Show the weather information to the user. Always use this tool to tell weather information to the user.',
parameters: z.object({
city: z.string(),
weather: z.string(),
temperature: z.number(),
typicalWeather: z
.string()
.describe(
'2-3 sentences about the typical weather in the city during spring.',
),
}),
},
},
});

return result.toAIStreamResponse();
}
69 changes: 69 additions & 0 deletions examples/next-openai/app/use-chat-streaming-tool-calls/page.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
'use client';

import { ToolInvocation } from 'ai';
import { Message, useChat } from 'ai/react';

export default function Chat() {
const { messages, input, handleInputChange, handleSubmit } = useChat({
api: '/api/use-chat-streaming-tool-calls',
maxToolRoundtrips: 5,

// run client-side tools that are automatically executed:
async onToolCall({ toolCall }) {
if (toolCall.toolName === 'showWeatherInformation') {
// display tool. add tool result that informs the llm that the tool was executed.
return 'Weather information was shown to the user.';
}
},
});

// used to only render the role when it changes:
let lastRole: string | undefined = undefined;

return (
<div className="flex flex-col w-full max-w-md py-24 mx-auto stretch">
{messages?.map((m: Message) => {
const isNewRole = m.role !== lastRole;
lastRole = m.role;

return (
<div key={m.id} className="whitespace-pre-wrap">
{isNewRole && <strong>{`${m.role}: `}</strong>}
{m.content}
{m.toolInvocations?.map((toolInvocation: ToolInvocation) => {
const { toolCallId, args } = toolInvocation;

// render display weather tool calls:
if (toolInvocation.toolName === 'showWeatherInformation') {
return (
<div
key={toolCallId}
className="p-4 my-2 text-gray-500 border border-gray-300 rounded"
>
<h4 className="mb-2">{args?.city ?? ''}</h4>
<div className="flex flex-col gap-2">
<div className="flex gap-2">
{args?.weather && <b>{args.weather}</b>}
{args?.temperature && <b>{args.temperature} &deg;C</b>}
</div>
{args?.typicalWeather && <div>{args.typicalWeather}</div>}
</div>
</div>
);
}
})}
</div>
);
})}

<form onSubmit={handleSubmit}>
<input
className="fixed bottom-0 w-full max-w-md p-2 mb-8 border border-gray-300 rounded shadow-xl"
value={input}
placeholder="Say something..."
onChange={handleInputChange}
/>
</form>
</div>
);
}
2 changes: 1 addition & 1 deletion packages/react/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"dependencies": {
"@ai-sdk/provider-utils": "1.0.2",
"@ai-sdk/ui-utils": "0.0.15",
"swr": "2.2.0"
"swr": "2.2.5"
},
"devDependencies": {
"@testing-library/jest-dom": "^6.4.5",
Expand Down
8 changes: 1 addition & 7 deletions packages/react/src/use-chat.ui.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -455,8 +455,6 @@ describe('onToolCall', () => {
});

describe('tool invocations', () => {
let rerender: RenderResult['rerender'];

const TestComponent = () => {
const { messages, append } = useChat();

Expand Down Expand Up @@ -485,8 +483,7 @@ describe('tool invocations', () => {
};

beforeEach(() => {
const result = render(<TestComponent />);
rerender = result.rerender;
render(<TestComponent />);
});

afterEach(() => {
Expand Down Expand Up @@ -522,7 +519,6 @@ describe('tool invocations', () => {
);

await waitFor(() => {
rerender(<TestComponent />);
expect(screen.getByTestId('message-1')).toHaveTextContent(
'{"state":"partial-call","toolCallId":"tool-call-0","toolName":"test-tool","args":{"testArg":"t"}}',
);
Expand All @@ -536,7 +532,6 @@ describe('tool invocations', () => {
);

await waitFor(() => {
rerender(<TestComponent />);
expect(screen.getByTestId('message-1')).toHaveTextContent(
'{"state":"partial-call","toolCallId":"tool-call-0","toolName":"test-tool","args":{"testArg":"test-value"}}',
);
Expand All @@ -551,7 +546,6 @@ describe('tool invocations', () => {
);

await waitFor(() => {
rerender(<TestComponent />);
expect(screen.getByTestId('message-1')).toHaveTextContent(
'{"state":"call","toolCallId":"tool-call-0","toolName":"test-tool","args":{"testArg":"test-value"}}',
);
Expand Down
8 changes: 8 additions & 0 deletions packages/ui-utils/src/parse-complex-response.ts
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ export async function parseComplexResponse({
toolName: partialToolCall.toolName,
args: parsePartialJson(partialToolCall.text),
};

// trigger update for streaming by copying adding a update id that changes
// (without it, the changes get stuck in SWR and are not forwarded to rendering):
(prefixMap.text! as any).internalUpdateId = generateId();
} else if (type === 'tool_call') {
if (partialToolCalls[value.toolCallId] != null) {
// change the partial tool call to a full tool call
Expand Down Expand Up @@ -152,6 +156,10 @@ export async function parseComplexResponse({
});
}

// trigger update for streaming by copying adding a update id that changes
// (without it, the changes get stuck in SWR and are not forwarded to rendering):
(prefixMap.text! as any).internalUpdateId = generateId();

// invoke the onToolCall callback if it exists. This is blocking.
// In the future we should make this non-blocking, which
// requires additional state management for error handling etc.
Expand Down
9 changes: 5 additions & 4 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 5b7b3bb

Please sign in to comment.