Skip to content

Commit

Permalink
onToolCall for useChat / React (#851)
Browse files Browse the repository at this point in the history
Co-authored-by: Max Leiter <maxwell.leiter@gmail.com>
Co-authored-by: Lars Grammel <lgrammel@Larss-MBP.fritz.box>
  • Loading branch information
3 people authored Jan 2, 2024
1 parent 9b89c4d commit 75751c9
Show file tree
Hide file tree
Showing 7 changed files with 327 additions and 42 deletions.
5 changes: 5 additions & 0 deletions .changeset/slow-students-roll.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'ai': patch
---

ai/react: Add experimental_onToolCall to useChat.
122 changes: 122 additions & 0 deletions examples/next-openai/app/api/chat-with-tools/route.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import {
OpenAIStream,
StreamingTextResponse,
Tool,
ToolCallPayload,
experimental_StreamData,
} from 'ai';
import OpenAI from 'openai';

// Create an OpenAI API client (that's edge friendly!)
// NOTE(review): the `|| ''` fallback keeps the constructor from throwing at
// module load when OPENAI_API_KEY is unset; requests will then fail with 401.
const openai = new OpenAI({
  apiKey: process.env.OPENAI_API_KEY || '',
});

// IMPORTANT! Set the runtime to edge
export const runtime = 'edge';

// Tool definitions advertised to the model on every request.
// `get_current_weather` is resolved server-side in the POST handler below;
// `eval_code_in_browser` is deliberately left unanswered by the server so the
// client-side `experimental_onToolCall` handler can execute it.
const tools: Tool[] = [
  {
    type: 'function',
    function: {
      name: 'get_current_weather',
      description: 'Get the current weather',
      // JSON Schema describing the arguments the model must generate.
      parameters: {
        type: 'object',
        properties: {
          location: {
            type: 'string',
            description: 'The city and state, e.g. San Francisco, CA',
          },
          format: {
            type: 'string',
            enum: ['celsius', 'fahrenheit'],
            description:
              'The temperature unit to use. Infer this from the users location.',
          },
        },
        required: ['location', 'format'],
      },
    },
  },
  {
    type: 'function',
    function: {
      name: 'eval_code_in_browser',
      description: 'Execute javascript code in the browser with eval().',
      parameters: {
        type: 'object',
        properties: {
          code: {
            type: 'string',
            description: `Javascript code that will be directly executed via eval(). Do not use backticks in your response.
DO NOT include any newlines in your response, and be sure to provide only valid JSON when providing the arguments object.
The output of the eval() will be returned directly by the function.`,
          },
        },
        required: ['code'],
      },
    },
  },
];

/**
 * POST /api/chat-with-tools
 *
 * Streams a chat completion from OpenAI with tool calling enabled.
 * `get_current_weather` calls are resolved here on the server and fed back to
 * the model for a follow-up completion; any other tool call is left unanswered
 * (the handler returns undefined) so the client can deal with it. A small
 * piece of extra data is attached to the response via `experimental_StreamData`.
 */
export async function POST(req: Request) {
  const { messages } = await req.json();

  const model = 'gpt-3.5-turbo-0613';

  const response = await openai.chat.completions.create({
    model,
    stream: true,
    messages,
    tools,
    tool_choice: 'auto',
  });

  const data = new experimental_StreamData();
  const stream = OpenAIStream(response, {
    experimental_onToolCall: async (
      call: ToolCallPayload,
      appendToolCallMessage,
    ) => {
      for (const invocation of call.tools) {
        // Note: this is a very simple example of a tool call handler
        // that only supports a single tool call function.
        if (invocation.func.name !== 'get_current_weather') {
          continue;
        }

        // Call a weather API here
        const weatherData = {
          temperature: 20,
          unit: invocation.func.arguments.format === 'celsius' ? 'C' : 'F',
        };

        // Append the tool result to the conversation so the model can see it.
        const newMessages = appendToolCallMessage({
          tool_call_id: invocation.id,
          function_name: 'get_current_weather',
          tool_call_result: weatherData,
        });

        // Ask the model for a follow-up completion that uses the tool result;
        // returning the stream continues the response to the client.
        return openai.chat.completions.create({
          messages: [...messages, ...newMessages],
          model,
          stream: true,
          tools,
          tool_choice: 'auto',
        });
      }
    },
    onCompletion(completion) {
      console.log('completion', completion);
    },
    onFinal(completion) {
      // Close the side-channel once the model stream is fully done.
      data.close();
    },
    experimental_streamData: true,
  });

  data.append({
    text: 'Hello, how are you?',
  });

  return new StreamingTextResponse(stream, {}, data);
}
3 changes: 3 additions & 0 deletions examples/next-openai/app/function-calling/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ export default function Chat() {
const parsedFunctionCallArguments: { code: string } = JSON.parse(
functionCall.arguments,
);

// WARNING: Do NOT do this in real-world applications!
eval(parsedFunctionCallArguments.code);

const functionResponse = {
messages: [
...chatMessages,
Expand All @@ -27,6 +29,7 @@ export default function Chat() {
},
],
};

return functionResponse;
}
}
Expand Down
86 changes: 86 additions & 0 deletions examples/next-openai/app/tool-calling/page.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
'use client';

import { ChatRequest, ToolCallHandler, nanoid } from 'ai';
import { Message, useChat } from 'ai/react';

/**
 * Chat page demonstrating client-side handling of OpenAI tool calls streamed
 * through `useChat({ experimental_onToolCall })`.
 */
export default function Chat() {
  // Invoked in the browser whenever the assistant requests tool calls.
  // Returns a follow-up ChatRequest when at least one call was handled,
  // or undefined to leave the tool calls unanswered.
  const toolCallHandler: ToolCallHandler = async (chatMessages, toolCalls) => {
    let handledFunction = false;
    for (const tool of toolCalls) {
      if (tool.type === 'function') {
        const { name, arguments: args } = tool.function;

        if (name === 'eval_code_in_browser') {
          let result: string;
          // Use try-catch to account for invalid JSON generated by the LLM:
          // some characters in generated code aren't always escaped properly.
          // (Matches the error handling in the SvelteKit example.)
          try {
            const parsedFunctionCallArguments: { code: string } =
              JSON.parse(args);

            // WARNING: Do NOT do this in real-world applications!
            eval(parsedFunctionCallArguments.code);

            result = parsedFunctionCallArguments.code;
          } catch (e) {
            // Report the failure back to the model instead of letting the
            // exception escape and break the chat flow.
            result = `Error: ${e}`;
          }

          if (result) {
            handledFunction = true;

            chatMessages.push({
              id: nanoid(),
              tool_call_id: tool.id,
              name: tool.function.name,
              role: 'tool' as const,
              content: result,
            });
          }
        }
      }
    }

    if (handledFunction) {
      const toolResponse: ChatRequest = { messages: chatMessages };
      return toolResponse;
    }
  };

  const { messages, input, handleInputChange, handleSubmit } = useChat({
    api: '/api/chat-with-tools',
    experimental_onToolCall: toolCallHandler,
  });

  // Generate a map of message role to text color
  const roleToColorMap: Record<Message['role'], string> = {
    system: 'red',
    user: 'black',
    function: 'blue',
    tool: 'purple',
    assistant: 'green',
    data: 'orange',
  };

  return (
    <div className="flex flex-col w-full max-w-md py-24 mx-auto stretch">
      {messages.length > 0
        ? messages.map((m: Message) => (
            <div
              key={m.id}
              className="whitespace-pre-wrap"
              style={{ color: roleToColorMap[m.role] }}
            >
              <strong>{`${m.role}: `}</strong>
              {m.content || JSON.stringify(m.function_call)}
              <br />
              <br />
            </div>
          ))
        : null}
      <div id="chart-goes-here"></div>
      <form onSubmit={handleSubmit}>
        <input
          className="fixed bottom-0 w-full max-w-md p-2 mb-8 border border-gray-300 rounded shadow-xl"
          value={input}
          placeholder="Say something..."
          onChange={handleInputChange}
        />
      </form>
    </div>
  );
}
56 changes: 23 additions & 33 deletions examples/sveltekit-openai/src/routes/chat-with-tools/+page.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -7,49 +7,39 @@
let handledFunction = false;
for (const tool of toolCalls) {
if (tool.type === 'function') {
const result = handleFunction(tool.function);
if (result) {
handledFunction = true;
chatMessages.push({
id: nanoid(),
tool_call_id: tool.id,
name: tool.function.name,
role: 'tool' as const,
content: result,
});
}
}
}
if (handledFunction) {
const toolResponse: ChatRequest = {
messages: chatMessages,
};
return toolResponse;
}
const { name, arguments: args } = tool.function;
function handleFunction({
name,
arguments: args,
}: {
name: string;
arguments: string;
}): string | undefined {
if (name === 'eval_code_in_browser' && args) {
// Use try-catch to account for invalid JSON generated by the LLM.
try {
if (name === 'eval_code_in_browser') {
// Parsing here does not always work since it seems that some characters in generated code aren't escaped properly.
const parsedFunctionCallArguments: { code: string } =
JSON.parse(args);
// WARNING: Do NOT do this in real-world applications!
eval(parsedFunctionCallArguments.code);
return parsedFunctionCallArguments.code;
} catch (e) {
return `Error: ${e}`;
const result = parsedFunctionCallArguments.code;
if (result) {
handledFunction = true;
chatMessages.push({
id: nanoid(),
tool_call_id: tool.id,
name: tool.function.name,
role: 'tool' as const,
content: result,
});
}
}
}
}
if (handledFunction) {
const toolResponse: ChatRequest = {
messages: chatMessages,
};
return toolResponse;
}
};
const { messages, input, handleSubmit } = useChat({
Expand Down
Loading

0 comments on commit 75751c9

Please sign in to comment.