diff --git a/.changeset/itchy-pumpkins-punch.md b/.changeset/itchy-pumpkins-punch.md
new file mode 100644
index 000000000000..489cbd958a3e
--- /dev/null
+++ b/.changeset/itchy-pumpkins-punch.md
@@ -0,0 +1,5 @@
+---
+'ai': patch
+---
+
+feat (core): add extractReasoningMiddleware
diff --git a/content/docs/03-ai-sdk-core/45-middleware.mdx b/content/docs/03-ai-sdk-core/45-middleware.mdx
index 187ed287da2b..7fe88e056ab5 100644
--- a/content/docs/03-ai-sdk-core/45-middleware.mdx
+++ b/content/docs/03-ai-sdk-core/45-middleware.mdx
@@ -40,6 +40,29 @@ const result = streamText({
 });
 ```
 
+## Built-in Middleware
+
+### Extract Reasoning
+
+Some providers and models expose reasoning information in the generated text using special tags,
+e.g. `<think>` and `</think>`.
+
+The `extractReasoningMiddleware` function extracts this reasoning information and exposes it as a `reasoning` property on the result.
+
+```ts
+import {
+  experimental_wrapLanguageModel as wrapLanguageModel,
+  extractReasoningMiddleware,
+} from 'ai';
+
+const model = wrapLanguageModel({
+  model: yourModel,
+  middleware: extractReasoningMiddleware({ tagName: 'think' }),
+});
+```
+
+You can then use that enhanced model in functions like `generateText` and `streamText`.
+
 ## Implementing Language Model Middleware
diff --git a/content/docs/07-reference/01-ai-sdk-core/66-extract-reasoning-middleware.mdx b/content/docs/07-reference/01-ai-sdk-core/66-extract-reasoning-middleware.mdx
new file mode 100644
index 000000000000..1d9fdd5218b6
--- /dev/null
+++ b/content/docs/07-reference/01-ai-sdk-core/66-extract-reasoning-middleware.mdx
@@ -0,0 +1,61 @@
+---
+title: extractReasoningMiddleware
+description: Middleware that extracts XML-tagged reasoning sections from generated text
+---
+
+# `extractReasoningMiddleware()`
+
+`extractReasoningMiddleware` is a middleware function that extracts XML-tagged reasoning sections from generated text and exposes them separately from the main text content. This is particularly useful when you want to separate an AI model's reasoning process from its final output.
+
+```ts
+import { extractReasoningMiddleware } from 'ai';
+
+const middleware = extractReasoningMiddleware({
+  tagName: 'reasoning',
+  separator: '\n',
+});
+```
+
+## Import
+
+<Snippet
+  text={`import { extractReasoningMiddleware } from "ai"`}
+  prompt={false}
+/>
+
+## API Signature
+
+### Parameters
+
+<PropertiesTable
+  content={[
+    {
+      name: 'tagName',
+      type: 'string',
+      isOptional: false,
+      description:
+        'The name of the XML tag whose content is extracted as reasoning (without angle brackets).',
+    },
+    {
+      name: 'separator',
+      type: 'string',
+      isOptional: true,
+      description:
+        'The separator to use between reasoning and text sections. Defaults to "\\n".',
+    },
+  ]}
+/>
+
+### Returns
+
+Returns a middleware object that:
+
+- Processes both streaming and non-streaming responses
+- Extracts content between the specified XML tags as reasoning
+- Removes the XML tags and reasoning from the main text
+- Adds a `reasoning` property to the result containing the extracted content
+- Joins the remaining text sections with the specified separator
+
+### Type Parameters
+
+The middleware works with the `LanguageModelV1StreamPart` type for streaming responses.
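+
+### Usage Example
+
+A minimal sketch of combining the middleware with `streamText` (`yourModel` is a placeholder for any provider model that emits `<think>` tags in its output):
+
+```ts
+import {
+  experimental_wrapLanguageModel as wrapLanguageModel,
+  extractReasoningMiddleware,
+  streamText,
+} from 'ai';
+
+const result = streamText({
+  model: wrapLanguageModel({
+    model: yourModel,
+    middleware: extractReasoningMiddleware({ tagName: 'think' }),
+  }),
+  prompt: 'Invent a new holiday and describe its traditions.',
+});
+
+// reasoning inside <think>...</think> is exposed separately from the text
+console.log('REASONING:', await result.reasoning);
+console.log('TEXT:', await result.text);
+```
+
+`result.fullStream` additionally emits `reasoning` parts alongside `text-delta` parts, as shown in the `examples/ai-core` streaming example.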
diff --git a/content/docs/07-reference/01-ai-sdk-core/index.mdx b/content/docs/07-reference/01-ai-sdk-core/index.mdx
index 00ec54f5dd56..33e948f23e32 100644
--- a/content/docs/07-reference/01-ai-sdk-core/index.mdx
+++ b/content/docs/07-reference/01-ai-sdk-core/index.mdx
@@ -87,6 +87,17 @@ It also contains the following helper functions:
       'Creates a ReadableStream that emits values with configurable delays.',
     href: '/docs/reference/ai-sdk-core/simulate-readable-stream',
   },
+  {
+    title: 'wrapLanguageModel()',
+    description: 'Wraps a language model with middleware.',
+    href: '/docs/reference/ai-sdk-core/wrap-language-model',
+  },
+  {
+    title: 'extractReasoningMiddleware()',
+    description:
+      'Extracts reasoning from the generated text and exposes it as a `reasoning` property on the result.',
+    href: '/docs/reference/ai-sdk-core/extract-reasoning-middleware',
+  },
   {
     title: 'smoothStream()',
     description: 'Smooths text streaming output.',
diff --git a/examples/ai-core/src/stream-text/deepseek-fullstream.ts b/examples/ai-core/src/stream-text/deepseek-reasoning-fullstream.ts
similarity index 100%
rename from examples/ai-core/src/stream-text/deepseek-fullstream.ts
rename to examples/ai-core/src/stream-text/deepseek-reasoning-fullstream.ts
diff --git a/examples/ai-core/src/stream-text/groq-reasoning-fullstream.ts b/examples/ai-core/src/stream-text/groq-reasoning-fullstream.ts
new file mode 100644
index 000000000000..ce3303984b54
--- /dev/null
+++ b/examples/ai-core/src/stream-text/groq-reasoning-fullstream.ts
@@ -0,0 +1,45 @@
+import { groq } from '@ai-sdk/groq';
+import {
+  experimental_wrapLanguageModel,
+  extractReasoningMiddleware,
+  streamText,
+} from 'ai';
+import 'dotenv/config';
+
+async function main() {
+  const result = streamText({
+    model: experimental_wrapLanguageModel({
+      model: groq('deepseek-r1-distill-llama-70b'),
+      middleware: extractReasoningMiddleware({ tagName: 'think' }),
+    }),
+    prompt: 'Invent a new holiday and describe its traditions.',
+  });
+
+  let enteredReasoning = false;
+  let enteredText = false;
+  for await (const part of result.fullStream) {
+    if (part.type === 'reasoning') {
+      if (!enteredReasoning) {
+        enteredReasoning = true;
+        console.log('\nSTREAMING REASONING:\n');
+      }
+      process.stdout.write(part.textDelta);
+    } else if (part.type === 'text-delta') {
+      if (!enteredText) {
+        enteredText = true;
+        console.log('\nSTREAMING TEXT:\n');
+      }
+      process.stdout.write(part.textDelta);
+    }
+  }
+
+  console.log();
+  console.log('\nFINAL REASONING:\n', await result.reasoning);
+  console.log('\nFINAL TEXT:\n', await result.text);
+
+  console.log();
+  console.log('Token usage:', await result.usage);
+  console.log('Finish reason:', await result.finishReason);
+}
+
+main().catch(console.error);
diff --git a/packages/ai/core/middleware/extract-reasoning-middleware.test.ts b/packages/ai/core/middleware/extract-reasoning-middleware.test.ts
new file mode 100644
index 000000000000..1fbfddee26c0
--- /dev/null
+++ b/packages/ai/core/middleware/extract-reasoning-middleware.test.ts
@@ -0,0 +1,272 @@
+import {
+  convertArrayToReadableStream,
+  convertAsyncIterableToArray,
+} from '@ai-sdk/provider-utils/test';
+import { generateText, streamText } from '../generate-text';
+import { experimental_wrapLanguageModel } from '../middleware/wrap-language-model';
+import { mockId } from '../test/mock-id';
+import { MockLanguageModelV1 } from '../test/mock-language-model-v1';
+import { extractReasoningMiddleware } from './extract-reasoning-middleware';
+
+describe('extractReasoningMiddleware', () => {
+  describe('wrapGenerate', () => {
+    it('should extract reasoning from tags', async () => {
+      const mockModel = new MockLanguageModelV1({
+        async doGenerate() {
+          return {
+            text: '<think>analyzing the request</think>Here is the response',
+            finishReason: 'stop',
+            usage: { promptTokens: 10, completionTokens: 10 },
+            rawCall: { rawPrompt: '', rawSettings: {} },
+          };
+        },
+      });
+
+      const result = await generateText({
+        model: experimental_wrapLanguageModel({
+          model: mockModel,
+          middleware: extractReasoningMiddleware({ tagName: 'think' }),
+        }),
+        prompt: 'Hello, how can I help?',
+      });
+
+      expect(result.reasoning).toStrictEqual('analyzing the request');
+      expect(result.text).toStrictEqual('Here is the response');
+    });
+
+    it('should extract reasoning from multiple tags', async () => {
+      const mockModel = new MockLanguageModelV1({
+        async doGenerate() {
+          return {
+            text: '<think>analyzing the request</think>Here is the response<think>thinking about the response</think>more',
+            finishReason: 'stop',
+            usage: { promptTokens: 10, completionTokens: 10 },
+            rawCall: { rawPrompt: '', rawSettings: {} },
+          };
+        },
+      });
+
+      const result = await generateText({
+        model: experimental_wrapLanguageModel({
+          model: mockModel,
+          middleware: extractReasoningMiddleware({ tagName: 'think' }),
+        }),
+        prompt: 'Hello, how can I help?',
+      });
+
+      expect(result.reasoning).toStrictEqual(
+        'analyzing the request\nthinking about the response',
+      );
+      expect(result.text).toStrictEqual('Here is the response\nmore');
+    });
+  });
+
+  describe('wrapStream', () => {
+    it('should extract reasoning from split tags', async () => {
+      const mockModel = new MockLanguageModelV1({
+        async doStream() {
+          return {
+            stream: convertArrayToReadableStream([
+              {
+                type: 'response-metadata',
+                id: 'id-0',
+                modelId: 'mock-model-id',
+                timestamp: new Date(0),
+              },
+              { type: 'text-delta', textDelta: '<th' },
+              { type: 'text-delta', textDelta: 'ink>ana' },
+              { type: 'text-delta', textDelta: 'lyzing the request' },
+              { type: 'text-delta', textDelta: '</th' },
+              { type: 'text-delta', textDelta: 'ink>Here' },
+              { type: 'text-delta', textDelta: ' is the response' },
+              {
+                type: 'finish',
+                finishReason: 'stop',
+                logprobs: undefined,
+                usage: { completionTokens: 10, promptTokens: 3 },
+              },
+            ]),
+            rawCall: { rawPrompt: '', rawSettings: {} },
+          };
+        },
+      });
+
+      const result = streamText({
+        model: experimental_wrapLanguageModel({
+          model: mockModel,
+          middleware: extractReasoningMiddleware({ tagName: 'think' }),
+        }),
+        prompt: 'Hello, how can I help?',
+        experimental_generateMessageId: mockId({ prefix: 'msg' }),
+      });
+
+      expect(
+        await convertAsyncIterableToArray(result.fullStream),
+      ).toStrictEqual([
+        {
+          messageId: 'msg-0',
+          request: {},
+          type: 'step-start',
+          warnings: [],
+        },
+        {
+          type: 'reasoning',
+          textDelta: 'ana',
+        },
+        {
+          type: 'reasoning',
+          textDelta: 'lyzing the request',
+        },
+        {
+          type: 'text-delta',
+          textDelta: 'Here',
+        },
+        {
+          type: 'text-delta',
+          textDelta: ' is the response',
+        },
+        {
+          experimental_providerMetadata: undefined,
+          finishReason: 'stop',
+          isContinued: false,
+          logprobs: undefined,
+          messageId: 'msg-0',
+          request: {},
+          response: {
+            headers: undefined,
+            id: 'id-0',
+            modelId: 'mock-model-id',
+            timestamp: new Date(0),
+          },
+          type: 'step-finish',
+          usage: {
+            completionTokens: 10,
+            promptTokens: 3,
+            totalTokens: 13,
+          },
+          warnings: undefined,
+        },
+        {
+          experimental_providerMetadata: undefined,
+          finishReason: 'stop',
+          logprobs: undefined,
+          response: {
+            headers: undefined,
+            id: 'id-0',
+            modelId: 'mock-model-id',
+            timestamp: new Date(0),
+          },
+          type: 'finish',
+          usage: {
+            completionTokens: 10,
+            promptTokens: 3,
+            totalTokens: 13,
+          },
+        },
+      ]);
+    });
+
+    it('should extract reasoning from single chunk with multiple tags', async () => {
+      const mockModel = new MockLanguageModelV1({
+        async doStream() {
+          return {
+            stream: convertArrayToReadableStream([
+              {
+                type: 'response-metadata',
+                id: 'id-0',
+                modelId: 'mock-model-id',
+                timestamp: new Date(0),
+              },
+              {
+                type: 'text-delta',
+                textDelta:
+                  '<think>analyzing the request</think>Here is the response<think>thinking about the response</think>more',
+              },
+              {
+                type: 'finish',
+                finishReason: 'stop',
+                logprobs: undefined,
+                usage: { completionTokens: 10, promptTokens: 3 },
+              },
+            ]),
+            rawCall: { rawPrompt: '', rawSettings: {} },
+          };
+        },
+      });
+
+      const result = streamText({
+        model: experimental_wrapLanguageModel({
+          model: mockModel,
+          middleware: extractReasoningMiddleware({ tagName: 'think' }),
+        }),
+        prompt: 'Hello, how can I help?',
+        experimental_generateMessageId: mockId({ prefix: 'msg' }),
+      });
+
+      expect(
+        await convertAsyncIterableToArray(result.fullStream),
+      ).toStrictEqual([
+        {
+          messageId: 'msg-0',
+          request: {},
+          type: 'step-start',
+          warnings: [],
+        },
+        {
+          type: 'reasoning',
+          textDelta: 'analyzing the request',
+        },
+        {
+          type: 'text-delta',
+          textDelta: 'Here is the response',
+        },
+        {
+          type: 'reasoning',
+          textDelta: '\nthinking about the response',
+        },
+        {
+          type: 'text-delta',
+          textDelta: '\nmore',
+        },
+        {
+          experimental_providerMetadata: undefined,
+          finishReason: 'stop',
+          isContinued: false,
+          logprobs: undefined,
+          messageId: 'msg-0',
+          request: {},
+          response: {
+            headers: undefined,
+            id: 'id-0',
+            modelId: 'mock-model-id',
+            timestamp: new Date(0),
+          },
+          type: 'step-finish',
+          usage: {
+            completionTokens: 10,
+            promptTokens: 3,
+            totalTokens: 13,
+          },
+          warnings: undefined,
+        },
+        {
+          experimental_providerMetadata: undefined,
+          finishReason: 'stop',
+          logprobs: undefined,
+          response: {
+            headers: undefined,
+            id: 'id-0',
+            modelId: 'mock-model-id',
+            timestamp: new Date(0),
+          },
+          type: 'finish',
+          usage: {
+            completionTokens: 10,
+            promptTokens: 3,
+            totalTokens: 13,
+          },
+        },
+      ]);
+    });
+  });
+});
diff --git a/packages/ai/core/middleware/extract-reasoning-middleware.ts b/packages/ai/core/middleware/extract-reasoning-middleware.ts
new file mode 100644
index 000000000000..3213ded930c5
--- /dev/null
+++ b/packages/ai/core/middleware/extract-reasoning-middleware.ts
@@ -0,0 +1,135 @@
+import { LanguageModelV1StreamPart } from '@ai-sdk/provider';
+import { getPotentialStartIndex } from '../util/get-potential-start-index';
+import { Experimental_LanguageModelV1Middleware } from './language-model-v1-middleware';
+
+/**
+ * Extracts an XML-tagged reasoning section from the generated text and exposes it
+ * as a `reasoning` property on the result.
+ *
+ * @param tagName - The name of the XML tag to extract reasoning from.
+ * @param separator - The separator to use between reasoning and text sections.
+ */
+export function extractReasoningMiddleware({
+  tagName,
+  separator = '\n',
+}: {
+  tagName: string;
+  separator?: string;
+}): Experimental_LanguageModelV1Middleware {
+  const openingTag = `<${tagName}>`;
+  const closingTag = `<\/${tagName}>`;
+
+  return {
+    wrapGenerate: async ({ doGenerate }) => {
+      const { text, ...rest } = await doGenerate();
+
+      if (text == null) {
+        return { text, ...rest };
+      }
+
+      const regexp = new RegExp(`${openingTag}(.*?)${closingTag}`, 'gs');
+      const matches = Array.from(text.matchAll(regexp));
+
+      if (!matches.length) {
+        return { text, ...rest };
+      }
+
+      const reasoning = matches.map(match => match[1]).join(separator);
+
+      let textWithoutReasoning = text;
+      for (let i = matches.length - 1; i >= 0; i--) {
+        const match = matches[i];
+
+        const beforeMatch = textWithoutReasoning.slice(0, match.index);
+        const afterMatch = textWithoutReasoning.slice(
+          match.index! + match[0].length,
+        );
+
+        textWithoutReasoning =
+          beforeMatch +
+          (beforeMatch.length > 0 && afterMatch.length > 0 ? separator : '') +
+          afterMatch;
+      }
+
+      return { text: textWithoutReasoning, reasoning, ...rest };
+    },
+
+    wrapStream: async ({ doStream }) => {
+      const { stream, ...rest } = await doStream();
+
+      let isFirstReasoning = true;
+      let isFirstText = true;
+      let afterSwitch = false;
+      let isReasoning: boolean = false;
+      let buffer = '';
+
+      return {
+        stream: stream.pipeThrough(
+          new TransformStream<
+            LanguageModelV1StreamPart,
+            LanguageModelV1StreamPart
+          >({
+            transform: (chunk, controller) => {
+              if (chunk.type !== 'text-delta') {
+                controller.enqueue(chunk);
+                return;
+              }
+
+              buffer += chunk.textDelta;
+
+              function publish(text: string) {
+                if (text.length > 0) {
+                  const prefix =
+                    afterSwitch &&
+                    (isReasoning ? !isFirstReasoning : !isFirstText)
+                      ? separator
+                      : '';
+
+                  controller.enqueue({
+                    type: isReasoning ? 'reasoning' : 'text-delta',
+                    textDelta: prefix + text,
+                  });
+                  afterSwitch = false;
+
+                  if (isReasoning) {
+                    isFirstReasoning = false;
+                  } else {
+                    isFirstText = false;
+                  }
+                }
+              }
+
+              do {
+                const nextTag = isReasoning ? closingTag : openingTag;
+                const startIndex = getPotentialStartIndex(buffer, nextTag);
+
+                // no opening or closing tag found, publish the buffer
+                if (startIndex == null) {
+                  publish(buffer);
+                  buffer = '';
+                  break;
+                }
+
+                // publish text before the tag
+                publish(buffer.slice(0, startIndex));
+
+                const foundFullMatch =
+                  startIndex + nextTag.length <= buffer.length;
+
+                if (foundFullMatch) {
+                  buffer = buffer.slice(startIndex + nextTag.length);
+                  isReasoning = !isReasoning;
+                  afterSwitch = true;
+                } else {
+                  buffer = buffer.slice(startIndex);
+                  break;
+                }
+              } while (true);
+            },
+          }),
+        ),
+        ...rest,
+      };
+    },
+  };
+}
diff --git a/packages/ai/core/middleware/index.ts b/packages/ai/core/middleware/index.ts
index cd39ed2f857d..14eec035716a 100644
--- a/packages/ai/core/middleware/index.ts
+++ b/packages/ai/core/middleware/index.ts
@@ -1,2 +1,3 @@
 export type { Experimental_LanguageModelV1Middleware } from './language-model-v1-middleware';
 export { experimental_wrapLanguageModel } from './wrap-language-model';
+export { extractReasoningMiddleware } from './extract-reasoning-middleware';
diff --git a/packages/ai/core/util/get-potential-start-index.test.ts b/packages/ai/core/util/get-potential-start-index.test.ts
new file mode 100644
index 000000000000..366103de0288
--- /dev/null
+++ b/packages/ai/core/util/get-potential-start-index.test.ts
@@ -0,0 +1,33 @@
+import { getPotentialStartIndex } from './get-potential-start-index';
+
+describe('getPotentialStartIndex', () => {
+  it('should return null when searchedText is empty', () => {
+    const result = getPotentialStartIndex('1234567890', '');
+    expect(result).toBeNull();
+  });
+
+  it('should return null when searchedText is not in text', () => {
+    const result = getPotentialStartIndex('1234567890', 'a');
+    expect(result).toBeNull();
+  });
+
+  it('should return index when searchedText is in text', () => {
+    const result = getPotentialStartIndex('1234567890', '1234567890');
+    expect(result).toBe(0);
+  });
+
+  it('should return index when searchedText might start in text', () => {
+    const result = getPotentialStartIndex('1234567890', '0123');
+    expect(result).toBe(9);
+  });
+
+  it('should return index when searchedText might start in text', () => {
+    const result = getPotentialStartIndex('1234567890', '90123');
+    expect(result).toBe(8);
+  });
+
+  it('should return index when searchedText might start in text', () => {
+    const result = getPotentialStartIndex('1234567890', '890123');
+    expect(result).toBe(7);
+  });
+});
diff --git a/packages/ai/core/util/get-potential-start-index.ts b/packages/ai/core/util/get-potential-start-index.ts
new file mode 100644
index 000000000000..8971daeac3bd
--- /dev/null
+++ b/packages/ai/core/util/get-potential-start-index.ts
@@ -0,0 +1,30 @@
+/**
+ * Returns the index of the start of the searchedText in the text, or null if it
+ * is not found.
+ */
+export function getPotentialStartIndex(
+  text: string,
+  searchedText: string,
+): number | null {
+  // Return null immediately if searchedText is empty.
+  if (searchedText.length === 0) {
+    return null;
+  }
+
+  // Check if the searchedText exists as a direct substring of text.
+  const directIndex = text.indexOf(searchedText);
+  if (directIndex !== -1) {
+    return directIndex;
+  }
+
+  // Otherwise, find the shortest suffix of "text" that is also a prefix of
+  // "searchedText", i.e. where a partial match could begin at the end of text.
+  for (let i = text.length - 1; i >= 0; i--) {
+    const suffix = text.substring(i);
+    if (searchedText.startsWith(suffix)) {
+      return i;
+    }
+  }
+
+  return null;
+}