From 4aa310ae462e2e9f64d312ae68781a2c4c1bc803 Mon Sep 17 00:00:00 2001 From: Lars Grammel Date: Mon, 27 Jan 2025 12:41:26 +0100 Subject: [PATCH 01/13] 1 --- .../extract-reasoning-middleware.test.ts | 33 +++++++++++++++++++ .../extract-reasoning-middleware.ts | 28 ++++++++++++++++ 2 files changed, 61 insertions(+) create mode 100644 packages/ai/core/middleware/extract-reasoning-middleware.test.ts create mode 100644 packages/ai/core/middleware/extract-reasoning-middleware.ts diff --git a/packages/ai/core/middleware/extract-reasoning-middleware.test.ts b/packages/ai/core/middleware/extract-reasoning-middleware.test.ts new file mode 100644 index 000000000000..63768e456b17 --- /dev/null +++ b/packages/ai/core/middleware/extract-reasoning-middleware.test.ts @@ -0,0 +1,33 @@ +import { generateText } from '../generate-text'; +import { experimental_wrapLanguageModel } from '../middleware/wrap-language-model'; +import { MockLanguageModelV1 } from '../test/mock-language-model-v1'; +import { extractReasoningMiddleware } from './extract-reasoning-middleware'; + +describe('extractReasoningMiddleware', () => { + it('should extract reasoning from tags during generation', async () => { + const mockModel = new MockLanguageModelV1({ + async doGenerate() { + return { + text: 'analyzing the requestHere is the response', + finishReason: 'stop', + usage: { promptTokens: 10, completionTokens: 10 }, + rawCall: { + rawPrompt: 'Hello, how can I help?', + rawSettings: {}, + }, + }; + }, + }); + + const result = await generateText({ + model: experimental_wrapLanguageModel({ + model: mockModel, + middleware: extractReasoningMiddleware({ tagName: 'think' }), + }), + prompt: 'Hello, how can I help?', + }); + + expect(result.text).toStrictEqual('Here is the response'); + expect(result.reasoning).toStrictEqual('analyzing the request'); + }); +}); diff --git a/packages/ai/core/middleware/extract-reasoning-middleware.ts b/packages/ai/core/middleware/extract-reasoning-middleware.ts new file mode 100644 index 000000000000..24df0d8cd95c --- /dev/null +++ b/packages/ai/core/middleware/extract-reasoning-middleware.ts @@ -0,0 +1,28 @@ +import { Experimental_LanguageModelV1Middleware } from './language-model-v1-middleware'; + +export function extractReasoningMiddleware({ + tagName, +}: { + tagName: string; +}): Experimental_LanguageModelV1Middleware { + return { + wrapGenerate: async ({ doGenerate, params, model }) => { + const result = await doGenerate(); + + if (result.text == null) { + return result; + } + + const regexp = new RegExp(`<${tagName}>(.*?)<\/${tagName}>`, 's'); + + const match = result.text.match(regexp); + + if (match) { + result.reasoning = match[1]; + result.text = result.text.replace(match[0], '').trim(); + } + + return result; + }, + }; +} From b9287fcaecca841cb9fe6bca964cf4012e48960f Mon Sep 17 00:00:00 2001 From: Lars Grammel Date: Mon, 27 Jan 2025 12:45:28 +0100 Subject: [PATCH 02/13] 2 --- .../extract-reasoning-middleware.test.ts | 31 ++++++++++++++++++- .../extract-reasoning-middleware.ts | 17 +++++++--- 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/packages/ai/core/middleware/extract-reasoning-middleware.test.ts b/packages/ai/core/middleware/extract-reasoning-middleware.test.ts index 63768e456b17..0bd2bebc7e25 100644 --- a/packages/ai/core/middleware/extract-reasoning-middleware.test.ts +++ b/packages/ai/core/middleware/extract-reasoning-middleware.test.ts @@ -27,7 +27,36 @@ describe('extractReasoningMiddleware', () => { prompt: 'Hello, how can I help?', }); - expect(result.text).toStrictEqual('Here is the response'); expect(result.reasoning).toStrictEqual('analyzing the request'); + expect(result.text).toStrictEqual('Here is the response'); + }); + + it('should extract multiple reasoning from tags during generation', async () => { + const mockModel = new MockLanguageModelV1({ + async doGenerate() { + return { + text: 'analyzing the requestHere is the responsethinking about the responsemore', + finishReason: 'stop', + usage: { promptTokens: 10, completionTokens: 10 }, + rawCall: { + rawPrompt: 'Hello, how can I help?', + rawSettings: {}, + }, + }; + }, + }); + + const result = await generateText({ + model: experimental_wrapLanguageModel({ + model: mockModel, + middleware: extractReasoningMiddleware({ tagName: 'think' }), + }), + prompt: 'Hello, how can I help?', + }); + + expect(result.reasoning).toStrictEqual( + 'analyzing the request\nthinking about the response', + ); + expect(result.text).toStrictEqual('Here is the response\nmore'); }); }); diff --git a/packages/ai/core/middleware/extract-reasoning-middleware.ts b/packages/ai/core/middleware/extract-reasoning-middleware.ts index 24df0d8cd95c..0f5970c40158 100644 --- a/packages/ai/core/middleware/extract-reasoning-middleware.ts +++ b/packages/ai/core/middleware/extract-reasoning-middleware.ts @@ -2,8 +2,10 @@ import { Experimental_LanguageModelV1Middleware } from './language-model-v1-midd export function extractReasoningMiddleware({ tagName, + separator = '\n', }: { tagName: string; + separator?: string; }): Experimental_LanguageModelV1Middleware { return { wrapGenerate: async ({ doGenerate, params, model }) => { @@ -13,13 +15,18 @@ export function extractReasoningMiddleware({ return result; } - const regexp = new RegExp(`<${tagName}>(.*?)<\/${tagName}>`, 's'); + const regexp = new RegExp(`<${tagName}>(.*?)<\/${tagName}>`, 'gs'); + const matches = Array.from(result.text.matchAll(regexp)); - const match = result.text.match(regexp); + if (matches.length > 0) { + // Combine all reasoning parts with the specified separator + result.reasoning = matches.map(match => match[1]).join(separator); - if (match) { - result.reasoning = match[1]; - result.text = result.text.replace(match[0], '').trim(); + // Remove all reasoning tags from the text and join remaining parts with separator + const parts = result.text + .split(new RegExp(`<${tagName}>.*?<\/${tagName}>`, 'gs')) + .filter(part => part.trim().length > 0); + result.text = parts.join(separator).trim(); } return result; From d6debc533e75d7fc62bce0150141b629c5e75178 Mon Sep 17 00:00:00 2001 From: Lars Grammel Date: Mon, 27 Jan 2025 12:54:28 +0100 Subject: [PATCH 03/13] 3 --- .../extract-reasoning-middleware.ts | 30 +++++++++++++------ 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/packages/ai/core/middleware/extract-reasoning-middleware.ts b/packages/ai/core/middleware/extract-reasoning-middleware.ts index 0f5970c40158..85410828a00b 100644 --- a/packages/ai/core/middleware/extract-reasoning-middleware.ts +++ b/packages/ai/core/middleware/extract-reasoning-middleware.ts @@ -8,7 +8,7 @@ export function extractReasoningMiddleware({ separator?: string; }): Experimental_LanguageModelV1Middleware { return { - wrapGenerate: async ({ doGenerate, params, model }) => { + wrapGenerate: async ({ doGenerate }) => { const result = await doGenerate(); if (result.text == null) { @@ -18,17 +18,29 @@ export function extractReasoningMiddleware({ const regexp = new RegExp(`<${tagName}>(.*?)<\/${tagName}>`, 'gs'); const matches = Array.from(result.text.matchAll(regexp)); - if (matches.length > 0) { - // Combine all reasoning parts with the specified separator - result.reasoning = matches.map(match => match[1]).join(separator); + if (!matches.length) { + return result; + } + + result.reasoning = matches.map(match => match[1]).join(separator); - // Remove all reasoning tags from the text and join remaining parts with separator - const parts = result.text - .split(new RegExp(`<${tagName}>.*?<\/${tagName}>`, 'gs')) - .filter(part => part.trim().length > 0); - result.text = parts.join(separator).trim(); + let textWithoutReasoning = result.text; + for (let i = matches.length - 1; i >= 0; i--) { + const match = matches[i]; + + const beforeMatch = textWithoutReasoning.slice(0, match.index); + const afterMatch = textWithoutReasoning.slice( + match.index! + match[0].length, + ); + + textWithoutReasoning = + beforeMatch + + (beforeMatch.length > 0 && afterMatch.length > 0 ? separator : '') + + afterMatch; } + result.text = textWithoutReasoning; + return result; }, }; From b0cc9c4e1b72ceaa16d09cc464b525b186789320 Mon Sep 17 00:00:00 2001 From: Lars Grammel Date: Mon, 27 Jan 2025 13:25:51 +0100 Subject: [PATCH 04/13] 4 --- .../util/get-potential-start-index.test.ts | 33 +++++++++++++++++++ .../ai/core/util/get-potential-start-index.ts | 30 +++++++++++++++++ 2 files changed, 63 insertions(+) create mode 100644 packages/ai/core/util/get-potential-start-index.test.ts create mode 100644 packages/ai/core/util/get-potential-start-index.ts diff --git a/packages/ai/core/util/get-potential-start-index.test.ts b/packages/ai/core/util/get-potential-start-index.test.ts new file mode 100644 index 000000000000..366103de0288 --- /dev/null +++ b/packages/ai/core/util/get-potential-start-index.test.ts @@ -0,0 +1,33 @@ +import { getPotentialStartIndex } from './get-potential-start-index'; + +describe('getPotentialStartIndex', () => { + it('should return null when searchedText is empty', () => { + const result = getPotentialStartIndex('1234567890', ''); + expect(result).toBeNull(); + }); + + it('should return null when searchedText is not in text', () => { + const result = getPotentialStartIndex('1234567890', 'a'); + expect(result).toBeNull(); + }); + + it('should return index when searchedText is in text', () => { + const result = getPotentialStartIndex('1234567890', '1234567890'); + expect(result).toBe(0); + }); + + it('should return index when searchedText might start in text', () => { + const result = getPotentialStartIndex('1234567890', '0123'); + expect(result).toBe(9); + }); + + it('should return index when searchedText might start in text', () => { + const result = getPotentialStartIndex('1234567890', '90123'); + expect(result).toBe(8); + }); + + it('should return index when searchedText might start in text', () => { + const result = getPotentialStartIndex('1234567890', '890123'); + expect(result).toBe(7); + }); +}); diff --git a/packages/ai/core/util/get-potential-start-index.ts b/packages/ai/core/util/get-potential-start-index.ts new file mode 100644 index 000000000000..8971daeac3bd --- /dev/null +++ b/packages/ai/core/util/get-potential-start-index.ts @@ -0,0 +1,30 @@ +/** + * Returns the index of the start of the searchedText in the text, or null if it + * is not found. + */ +export function getPotentialStartIndex( + text: string, + searchedText: string, +): number | null { + // Return null immediately if searchedText is empty. + if (searchedText.length === 0) { + return null; + } + + // Check if the searchedText exists as a direct substring of text. + const directIndex = text.indexOf(searchedText); + if (directIndex !== -1) { + return directIndex; + } + + // Otherwise, look for the largest suffix of "text" that matches + // a prefix of "searchedText". We go from the end of text inward. + for (let i = text.length - 1; i >= 0; i--) { + const suffix = text.substring(i); + if (searchedText.startsWith(suffix)) { + return i; + } + } + + return null; +} From f5959d0c846fb437429f5043fb8b66f7ec2d5402 Mon Sep 17 00:00:00 2001 From: Lars Grammel Date: Mon, 27 Jan 2025 13:42:28 +0100 Subject: [PATCH 05/13] 5 --- .../extract-reasoning-middleware.test.ts | 201 ++++++++++++++---- .../extract-reasoning-middleware.ts | 92 +++++++- 2 files changed, 236 insertions(+), 57 deletions(-) diff --git a/packages/ai/core/middleware/extract-reasoning-middleware.test.ts b/packages/ai/core/middleware/extract-reasoning-middleware.test.ts index 0bd2bebc7e25..2496a6178410 100644 --- a/packages/ai/core/middleware/extract-reasoning-middleware.test.ts +++ b/packages/ai/core/middleware/extract-reasoning-middleware.test.ts @@ -1,62 +1,169 @@ -import { generateText } from '../generate-text'; +import { + convertArrayToReadableStream, + convertAsyncIterableToArray, +} from '@ai-sdk/provider-utils/test'; +import { generateText, streamText } from '../generate-text'; import { experimental_wrapLanguageModel } from '../middleware/wrap-language-model'; import { MockLanguageModelV1 } from '../test/mock-language-model-v1'; import { extractReasoningMiddleware } from './extract-reasoning-middleware'; +import { mockId } from '../test/mock-id'; describe('extractReasoningMiddleware', () => { - it('should extract reasoning from tags during generation', async () => { - const mockModel = new MockLanguageModelV1({ - async doGenerate() { - return { - text: 'analyzing the requestHere is the response', - finishReason: 'stop', - usage: { promptTokens: 10, completionTokens: 10 }, - rawCall: { - rawPrompt: 'Hello, how can I help?', - rawSettings: {}, - }, - }; - }, - }); + describe('wrapGenerate', () => { + it('should extract reasoning from tags during generation', async () => { + const mockModel = new MockLanguageModelV1({ + async doGenerate() { + return { + text: 'analyzing the requestHere is the response', + finishReason: 'stop', + usage: { promptTokens: 10, completionTokens: 10 }, + rawCall: { rawPrompt: '', rawSettings: {} }, + }; + }, + }); + + const result = await generateText({ + model: experimental_wrapLanguageModel({ + model: mockModel, + middleware: extractReasoningMiddleware({ tagName: 'think' }), + }), + prompt: 'Hello, how can I help?', + }); - const result = await generateText({ - model: experimental_wrapLanguageModel({ - model: mockModel, - middleware: extractReasoningMiddleware({ tagName: 'think' }), - }), - prompt: 'Hello, how can I help?', + expect(result.reasoning).toStrictEqual('analyzing the request'); + expect(result.text).toStrictEqual('Here is the response'); }); - expect(result.reasoning).toStrictEqual('analyzing the request'); - expect(result.text).toStrictEqual('Here is the response'); + it('should extract multiple reasoning from tags during generation', async () => { + const mockModel = new MockLanguageModelV1({ + async doGenerate() { + return { + text: 'analyzing the requestHere is the responsethinking about the responsemore', + finishReason: 'stop', + usage: { promptTokens: 10, completionTokens: 10 }, + rawCall: { rawPrompt: '', rawSettings: {} }, + }; + }, + }); + + const result = await generateText({ + model: experimental_wrapLanguageModel({ + model: mockModel, + middleware: extractReasoningMiddleware({ tagName: 'think' }), + }), + prompt: 'Hello, how can I help?', + }); + + expect(result.reasoning).toStrictEqual( + 'analyzing the request\nthinking about the response', + ); + expect(result.text).toStrictEqual('Here is the response\nmore'); + }); }); - it('should extract multiple reasoning from tags during generation', async () => { - const mockModel = new MockLanguageModelV1({ - async doGenerate() { - return { - text: 'analyzing the requestHere is the responsethinking about the responsemore', + describe('wrapStream', () => { + it('should extract reasoning from tags during streaming', async () => { + const mockModel = new MockLanguageModelV1({ + async doStream() { + return { + stream: convertArrayToReadableStream([ + { + type: 'response-metadata', + id: 'id-0', + modelId: 'mock-model-id', + timestamp: new Date(0), + }, + { type: 'text-delta', textDelta: 'ana' }, + { type: 'text-delta', textDelta: 'lyzing the request' }, + { type: 'text-delta', textDelta: 'Here' }, + { type: 'text-delta', textDelta: ' is the response' }, + { + type: 'finish', + finishReason: 'stop', + logprobs: undefined, + usage: { completionTokens: 10, promptTokens: 3 }, + }, + ]), + rawCall: { rawPrompt: '', rawSettings: {} }, + }; + }, + }); + + const result = streamText({ + model: experimental_wrapLanguageModel({ + model: mockModel, + middleware: extractReasoningMiddleware({ tagName: 'think' }), + }), + prompt: 'Hello, how can I help?', + experimental_generateMessageId: mockId({ prefix: 'msg' }), + }); + + expect( + await convertAsyncIterableToArray(result.fullStream), + ).toStrictEqual([ + { + messageId: 'msg-0', + request: {}, + type: 'step-start', + warnings: [], + }, + { + type: 'reasoning', + textDelta: 'ana', + }, + { + type: 'reasoning', + textDelta: 'lyzing the request', + }, + { + type: 'text-delta', + textDelta: 'Here', + }, + { + type: 'text-delta', + textDelta: ' is the response', + }, + { + experimental_providerMetadata: undefined, finishReason: 'stop', - usage: { promptTokens: 10, completionTokens: 10 }, - rawCall: { - rawPrompt: 'Hello, how can I help?', - rawSettings: {}, + isContinued: false, + logprobs: undefined, + messageId: 'msg-0', + request: {}, + response: { + headers: undefined, + id: 'id-0', + modelId: 'mock-model-id', + timestamp: new Date(0), }, - }; - }, - }); - - const result = await generateText({ - model: experimental_wrapLanguageModel({ - model: mockModel, - middleware: extractReasoningMiddleware({ tagName: 'think' }), - }), - prompt: 'Hello, how can I help?', + type: 'step-finish', + usage: { + completionTokens: 10, + promptTokens: 3, + totalTokens: 13, + }, + warnings: undefined, + }, + { + experimental_providerMetadata: undefined, + finishReason: 'stop', + logprobs: undefined, + response: { + headers: undefined, + id: 'id-0', + modelId: 'mock-model-id', + timestamp: new Date(0), + }, + type: 'finish', + usage: { + completionTokens: 10, + promptTokens: 3, + totalTokens: 13, + }, + }, + ]); }); - - expect(result.reasoning).toStrictEqual( - 'analyzing the request\nthinking about the response', - ); - expect(result.text).toStrictEqual('Here is the response\nmore'); }); }); diff --git a/packages/ai/core/middleware/extract-reasoning-middleware.ts b/packages/ai/core/middleware/extract-reasoning-middleware.ts index 85410828a00b..d2403945dbca 100644 --- a/packages/ai/core/middleware/extract-reasoning-middleware.ts +++ b/packages/ai/core/middleware/extract-reasoning-middleware.ts @@ -1,5 +1,15 @@ +import { LanguageModelV1StreamPart } from '@ai-sdk/provider'; import { Experimental_LanguageModelV1Middleware } from './language-model-v1-middleware'; +import { getPotentialStartIndex } from '../util/get-potential-start-index'; +import { read } from 'fs'; +/** + * Extract an XML-tagged reasoning section from the generated text and exposes it + * as a `reasoning` property on the result. + * + * @param tagName - The name of the XML tag to extract reasoning from. + * @param separator - The separator to use between reasoning and text sections. + */ export function extractReasoningMiddleware({ tagName, separator = '\n', @@ -7,24 +17,27 @@ export function extractReasoningMiddleware({ tagName: string; separator?: string; }): Experimental_LanguageModelV1Middleware { + const openingTag = `<${tagName}>`; + const closingTag = `<\/${tagName}>`; + return { wrapGenerate: async ({ doGenerate }) => { - const result = await doGenerate(); + const { text, ...rest } = await doGenerate(); - if (result.text == null) { - return result; + if (text == null) { + return { text, ...rest }; } - const regexp = new RegExp(`<${tagName}>(.*?)<\/${tagName}>`, 'gs'); - const matches = Array.from(result.text.matchAll(regexp)); + const regexp = new RegExp(`${openingTag}(.*?)${closingTag}`, 'gs'); + const matches = Array.from(text.matchAll(regexp)); if (!matches.length) { - return result; + return { text, ...rest }; } - result.reasoning = matches.map(match => match[1]).join(separator); + const reasoning = matches.map(match => match[1]).join(separator); - let textWithoutReasoning = result.text; + let textWithoutReasoning = text; for (let i = matches.length - 1; i >= 0; i--) { const match = matches[i]; @@ -39,9 +52,68 @@ export function extractReasoningMiddleware({ afterMatch; } - result.text = textWithoutReasoning; + return { text: textWithoutReasoning, reasoning, ...rest }; + }, + + wrapStream: async ({ doStream }) => { + const { stream, ...rest } = await doStream(); + + let isReasoning: boolean = false; + let buffer = ''; + + return { + stream: stream.pipeThrough( + new TransformStream< + LanguageModelV1StreamPart, + LanguageModelV1StreamPart + >({ + transform: (chunk, controller) => { + if (chunk.type !== 'text-delta') { + controller.enqueue(chunk); + return; + } + + buffer += chunk.textDelta; + + function publish(text: string) { + if (text.length > 0) { + controller.enqueue({ + type: isReasoning ? 'reasoning' : 'text-delta', + textDelta: text, + }); + } + } + + do { + const nextTag = isReasoning ? closingTag : openingTag; + const startIndex = getPotentialStartIndex(buffer, nextTag); + + // no opening or closing tag found, publish the buffer + if (startIndex == null) { + publish(buffer); + buffer = ''; + break; + } + + // publish text before the tag + publish(buffer.slice(0, startIndex)); + + const foundFullMatch = + startIndex + nextTag.length <= buffer.length; - return result; + if (foundFullMatch) { + buffer = buffer.slice(startIndex + nextTag.length); + isReasoning = !isReasoning; + } else { + buffer = buffer.slice(startIndex); + break; + } + } while (true); + }, + }), + ), + ...rest, + }; }, }; } From 47b0cebee2610f612f01ea2c4fb5d183fca9c5a3 Mon Sep 17 00:00:00 2001 From: Lars Grammel Date: Mon, 27 Jan 2025 13:43:23 +0100 Subject: [PATCH 06/13] 6 --- .../ai/core/middleware/extract-reasoning-middleware.test.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/ai/core/middleware/extract-reasoning-middleware.test.ts b/packages/ai/core/middleware/extract-reasoning-middleware.test.ts index 2496a6178410..1884a6a17655 100644 --- a/packages/ai/core/middleware/extract-reasoning-middleware.test.ts +++ b/packages/ai/core/middleware/extract-reasoning-middleware.test.ts @@ -10,7 +10,7 @@ import { mockId } from '../test/mock-id'; describe('extractReasoningMiddleware', () => { describe('wrapGenerate', () => { - it('should extract reasoning from tags during generation', async () => { + it('should extract reasoning from tags', async () => { const mockModel = new MockLanguageModelV1({ async doGenerate() { return { @@ -34,7 +34,7 @@ describe('extractReasoningMiddleware', () => { expect(result.text).toStrictEqual('Here is the response'); }); - it('should extract multiple reasoning from tags during generation', async () => { + it('should extract reasoning from multiple tags', async () => { const mockModel = new MockLanguageModelV1({ async doGenerate() { return { @@ -62,7 +62,7 @@ describe('extractReasoningMiddleware', () => { }); describe('wrapStream', () => { - it('should extract reasoning from tags during streaming', async () => { + it('should extract reasoning from split tags', async () => { const mockModel = new MockLanguageModelV1({ async doStream() { return { From 0b7d6e9abf57afd87a5ca5c7bc1098e8c46f695c Mon Sep 17 00:00:00 2001 From: Lars Grammel Date: Mon, 27 Jan 2025 13:49:52 +0100 Subject: [PATCH 07/13] 7 --- .../extract-reasoning-middleware.test.ts | 103 ++++++++++++++++++ .../extract-reasoning-middleware.ts | 19 +++- 2 files changed, 121 insertions(+), 1 deletion(-) diff --git a/packages/ai/core/middleware/extract-reasoning-middleware.test.ts b/packages/ai/core/middleware/extract-reasoning-middleware.test.ts index 1884a6a17655..0f91e2e130b9 100644 --- a/packages/ai/core/middleware/extract-reasoning-middleware.test.ts +++ b/packages/ai/core/middleware/extract-reasoning-middleware.test.ts @@ -165,5 +165,108 @@ describe('extractReasoningMiddleware', () => { }, ]); }); + + it('should extract reasoning from single chunk with multiple tags', async () => { + const mockModel = new MockLanguageModelV1({ + async doStream() { + return { + stream: convertArrayToReadableStream([ + { + type: 'response-metadata', + id: 'id-0', + modelId: 'mock-model-id', + timestamp: new Date(0), + }, + { + type: 'text-delta', + textDelta: + 'analyzing the requestHere is the responsethinking about the responsemore', + }, + { + type: 'finish', + finishReason: 'stop', + logprobs: undefined, + usage: { completionTokens: 10, promptTokens: 3 }, + }, + ]), + rawCall: { rawPrompt: '', rawSettings: {} }, + }; + }, + }); + + const result = streamText({ + model: experimental_wrapLanguageModel({ + model: mockModel, + middleware: extractReasoningMiddleware({ tagName: 'think' }), + }), + prompt: 'Hello, how can I help?', + experimental_generateMessageId: mockId({ prefix: 'msg' }), + }); + + expect( + await convertAsyncIterableToArray(result.fullStream), + ).toStrictEqual([ + { + messageId: 'msg-0', + request: {}, + type: 'step-start', + warnings: [], + }, + { + type: 'reasoning', + textDelta: 'analyzing the request', + }, + { + type: 'text-delta', + textDelta: 'Here is the response', + }, + { + type: 'reasoning', + textDelta: '\nthinking about the response', + }, + { + type: 'text-delta', + textDelta: '\nmore', + }, + { + experimental_providerMetadata: undefined, + finishReason: 'stop', + isContinued: false, + logprobs: undefined, + messageId: 'msg-0', + request: {}, + response: { + headers: undefined, + id: 'id-0', + modelId: 'mock-model-id', + timestamp: new Date(0), + }, + type: 'step-finish', + usage: { + completionTokens: 10, + promptTokens: 3, + totalTokens: 13, + }, + warnings: undefined, + }, + { + experimental_providerMetadata: undefined, + finishReason: 'stop', + logprobs: undefined, + response: { + headers: undefined, + id: 'id-0', + modelId: 'mock-model-id', + timestamp: new Date(0), + }, + type: 'finish', + usage: { + completionTokens: 10, + promptTokens: 3, + totalTokens: 13, + }, + }, + ]); + }); }); }); diff --git a/packages/ai/core/middleware/extract-reasoning-middleware.ts b/packages/ai/core/middleware/extract-reasoning-middleware.ts index d2403945dbca..d2261c23f21a 100644 --- a/packages/ai/core/middleware/extract-reasoning-middleware.ts +++ b/packages/ai/core/middleware/extract-reasoning-middleware.ts @@ -58,6 +58,9 @@ export function extractReasoningMiddleware({ wrapStream: async ({ doStream }) => { const { stream, ...rest } = await doStream(); + let isFirstReasoning = true; + let isFirstText = true; + let afterSwitch = false; let isReasoning: boolean = false; let buffer = ''; @@ -77,10 +80,23 @@ export function extractReasoningMiddleware({ function publish(text: string) { if (text.length > 0) { + const prefix = + afterSwitch && + (isReasoning ? !isFirstReasoning : !isFirstText) + ? separator + : ''; + controller.enqueue({ type: isReasoning ? 'reasoning' : 'text-delta', - textDelta: text, + textDelta: prefix + text, }); + afterSwitch = false; + + if (isReasoning) { + isFirstReasoning = false; + } else { + isFirstText = false; + } } } @@ -104,6 +120,7 @@ export function extractReasoningMiddleware({ if (foundFullMatch) { buffer = buffer.slice(startIndex + nextTag.length); isReasoning = !isReasoning; + afterSwitch = true; } else { buffer = buffer.slice(startIndex); break; From 0a0b0a54daae5c183e4d984579e38b1e623aa0a0 Mon Sep 17 00:00:00 2001 From: Lars Grammel Date: Mon, 27 Jan 2025 13:50:09 +0100 Subject: [PATCH 08/13] clean --- .../ai/core/middleware/extract-reasoning-middleware.test.ts | 2 +- packages/ai/core/middleware/extract-reasoning-middleware.ts | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/packages/ai/core/middleware/extract-reasoning-middleware.test.ts b/packages/ai/core/middleware/extract-reasoning-middleware.test.ts index 0f91e2e130b9..1fbfddee26c0 100644 --- a/packages/ai/core/middleware/extract-reasoning-middleware.test.ts +++ b/packages/ai/core/middleware/extract-reasoning-middleware.test.ts @@ -4,9 +4,9 @@ import { } from '@ai-sdk/provider-utils/test'; import { generateText, streamText } from '../generate-text'; import { experimental_wrapLanguageModel } from '../middleware/wrap-language-model'; +import { mockId } from '../test/mock-id'; import { MockLanguageModelV1 } from '../test/mock-language-model-v1'; import { extractReasoningMiddleware } from './extract-reasoning-middleware'; -import { mockId } from '../test/mock-id'; describe('extractReasoningMiddleware', () => { describe('wrapGenerate', () => { diff --git a/packages/ai/core/middleware/extract-reasoning-middleware.ts b/packages/ai/core/middleware/extract-reasoning-middleware.ts index d2261c23f21a..3213ded930c5 100644 --- a/packages/ai/core/middleware/extract-reasoning-middleware.ts +++ b/packages/ai/core/middleware/extract-reasoning-middleware.ts @@ -1,7 +1,6 @@ import { LanguageModelV1StreamPart } from '@ai-sdk/provider'; -import { Experimental_LanguageModelV1Middleware } from './language-model-v1-middleware'; import { getPotentialStartIndex } from '../util/get-potential-start-index'; -import { read } from 'fs'; +import { Experimental_LanguageModelV1Middleware } from './language-model-v1-middleware'; /** * Extract an XML-tagged reasoning section from the generated text and exposes it From 74700f48783b59b12a559567902b75da62495e09 Mon Sep 17 00:00:00 2001 From: Lars Grammel Date: Mon, 27 Jan 2025 14:45:16 +0100 Subject: [PATCH 09/13] example --- ...am.ts => deepseek-reasoning-fullstream.ts} | 0 .../stream-text/groq-reasoning-fullstream.ts | 45 +++++++++++++++++++ packages/ai/core/middleware/index.ts | 1 + 3 files changed, 46 insertions(+) rename examples/ai-core/src/stream-text/{deepseek-fullstream.ts => deepseek-reasoning-fullstream.ts} (100%) create mode 100644 examples/ai-core/src/stream-text/groq-reasoning-fullstream.ts diff --git a/examples/ai-core/src/stream-text/deepseek-fullstream.ts b/examples/ai-core/src/stream-text/deepseek-reasoning-fullstream.ts similarity index 100% rename from examples/ai-core/src/stream-text/deepseek-fullstream.ts rename to examples/ai-core/src/stream-text/deepseek-reasoning-fullstream.ts diff --git a/examples/ai-core/src/stream-text/groq-reasoning-fullstream.ts b/examples/ai-core/src/stream-text/groq-reasoning-fullstream.ts new file mode 100644 index 000000000000..ce3303984b54 --- /dev/null +++ b/examples/ai-core/src/stream-text/groq-reasoning-fullstream.ts @@ -0,0 +1,45 @@ +import { groq } from '@ai-sdk/groq'; +import { + experimental_wrapLanguageModel, + extractReasoningMiddleware, + streamText, +} from 'ai'; +import 'dotenv/config'; + +async function main() { + const result = streamText({ + model: experimental_wrapLanguageModel({ + model: groq('deepseek-r1-distill-llama-70b'), + middleware: extractReasoningMiddleware({ tagName: 'think' }), + }), + prompt: 'Invent a new holiday and describe its traditions.', + }); + + let enteredReasoning = false; + let enteredText = false; + for await (const part of result.fullStream) { + if (part.type === 'reasoning') { + if (!enteredReasoning) { + enteredReasoning = true; + console.log('\nSTREAMING REASONING:\n'); + } + process.stdout.write(part.textDelta); + } else if (part.type === 'text-delta') { + if (!enteredText) { + enteredText = true; + console.log('\nSTREAMING TEXT:\n'); + } + process.stdout.write(part.textDelta); + } + } + + console.log(); + console.log('\nFINAL REASONING:\n', await result.reasoning); + console.log('\nFINAL TEXT:\n', await result.text); + + console.log(); + console.log('Token usage:', await result.usage); + console.log('Finish reason:', await result.finishReason); +} + +main().catch(console.error); diff --git a/packages/ai/core/middleware/index.ts b/packages/ai/core/middleware/index.ts index cd39ed2f857d..14eec035716a 100644 --- a/packages/ai/core/middleware/index.ts +++ b/packages/ai/core/middleware/index.ts @@ -1,2 +1,3 @@ export type { Experimental_LanguageModelV1Middleware } from './language-model-v1-middleware'; export { experimental_wrapLanguageModel } from './wrap-language-model'; +export { extractReasoningMiddleware } from './extract-reasoning-middleware'; From 8a0234c48ab702606c3fb0aba7713578f92e610f Mon Sep 17 00:00:00 2001 From: Lars Grammel Date: Mon, 27 Jan 2025 14:46:28 +0100 Subject: [PATCH 10/13] changset --- .changeset/itchy-pumpkins-punch.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .changeset/itchy-pumpkins-punch.md diff --git a/.changeset/itchy-pumpkins-punch.md b/.changeset/itchy-pumpkins-punch.md new file mode 100644 index 000000000000..489cbd958a3e --- /dev/null +++ b/.changeset/itchy-pumpkins-punch.md @@ -0,0 +1,5 @@ +--- +'ai': patch +--- + +feat (core): add extractReasoningMiddleware From 6a30a9f6088658c10e04fefe23c7646be84f2f52 Mon Sep 17 00:00:00 2001 From: Lars Grammel Date: Mon, 27 Jan 2025 14:52:00 +0100 Subject: [PATCH 11/13] docs --- content/docs/03-ai-sdk-core/45-middleware.mdx | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/content/docs/03-ai-sdk-core/45-middleware.mdx b/content/docs/03-ai-sdk-core/45-middleware.mdx index 187ed287da2b..5cae494168c3 100644 --- a/content/docs/03-ai-sdk-core/45-middleware.mdx +++ b/content/docs/03-ai-sdk-core/45-middleware.mdx @@ -40,6 +40,27 @@ const result = streamText({ }); ``` +## Built-in Middleware + +### Extract Reasoning + +Some providers and models expose reasoning information in the generated text using special tags, +e.g. <think> and </think>. + +The `extractReasoningMiddleware` function can be used to extract this reasoning information and expose it as a `reasoning` property on the result. + +```ts +import { + experimental_wrapLanguageModel as wrapLanguageModel, + extractReasoningMiddleware, +} from 'ai'; + +const wrappedLanguageModel = wrapLanguageModel({ + model: yourModel, + middleware: extractReasoningMiddleware({ tagName: 'think' }), +}); +``` + ## Implementing Language Model Middleware From e5cb8132179828136a9bb44790a8ebc0f76e1d6c Mon Sep 17 00:00:00 2001 From: Lars Grammel Date: Mon, 27 Jan 2025 14:55:03 +0100 Subject: [PATCH 12/13] ref --- .../66-extract-reasoning-middleware.mdx | 61 +++++++++++++++++++ .../07-reference/01-ai-sdk-core/index.mdx | 11 ++++ 2 files changed, 72 insertions(+) create mode 100644 content/docs/07-reference/01-ai-sdk-core/66-extract-reasoning-middleware.mdx diff --git a/content/docs/07-reference/01-ai-sdk-core/66-extract-reasoning-middleware.mdx b/content/docs/07-reference/01-ai-sdk-core/66-extract-reasoning-middleware.mdx new file mode 100644 index 000000000000..1d9fdd5218b6 --- /dev/null +++ b/content/docs/07-reference/01-ai-sdk-core/66-extract-reasoning-middleware.mdx @@ -0,0 +1,61 @@ +--- +title: extractReasoningMiddleware +description: Middleware that extracts XML-tagged reasoning sections from generated text +--- + +# `extractReasoningMiddleware()` + +`extractReasoningMiddleware` is a middleware function that extracts XML-tagged reasoning sections from generated text and exposes them separately from the main text content. This is particularly useful when you want to separate an AI model's reasoning process from its final output. + +```ts +import { extractReasoningMiddleware } from 'ai'; + +const middleware = extractReasoningMiddleware({ + tagName: 'reasoning', + separator: '\n', +}); +``` + +## Import + + + +## API Signature + +### Parameters + + + +### Returns + +Returns a middleware object that: + +- Processes both streaming and non-streaming responses +- Extracts content between specified XML tags as reasoning +- Removes the XML tags and reasoning from the main text +- Adds a `reasoning` property to the result containing the extracted content +- Maintains proper separation between text sections using the specified separator + +### Type Parameters + +The middleware works with the `LanguageModelV1StreamPart` type for streaming responses. diff --git a/content/docs/07-reference/01-ai-sdk-core/index.mdx b/content/docs/07-reference/01-ai-sdk-core/index.mdx index 00ec54f5dd56..33e948f23e32 100644 --- a/content/docs/07-reference/01-ai-sdk-core/index.mdx +++ b/content/docs/07-reference/01-ai-sdk-core/index.mdx @@ -87,6 +87,17 @@ It also contains the following helper functions: 'Creates a ReadableStream that emits values with configurable delays.', href: '/docs/reference/ai-sdk-core/simulate-readable-stream', }, + { + title: 'wrapLanguageModel()', + description: 'Wraps a language model with middleware.', + href: '/docs/reference/ai-sdk-core/wrap-language-model', + }, + { + title: 'extractReasoningMiddleware()', + description: + 'Extracts reasoning from the generated text and exposes it as a `reasoning` property on the result.', + href: '/docs/reference/ai-sdk-core/extract-reasoning-middleware', + }, { title: 'smoothStream()', description: 'Smooths text streaming output.', From a19df8c40d9c734fba7f782d655f92db2c5e6710 Mon Sep 17 00:00:00 2001 From: Lars Grammel Date: Mon, 27 Jan 2025 15:01:54 +0100 Subject: [PATCH 13/13] note --- content/docs/03-ai-sdk-core/45-middleware.mdx | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/content/docs/03-ai-sdk-core/45-middleware.mdx b/content/docs/03-ai-sdk-core/45-middleware.mdx index 5cae494168c3..7fe88e056ab5 100644 --- a/content/docs/03-ai-sdk-core/45-middleware.mdx +++ b/content/docs/03-ai-sdk-core/45-middleware.mdx @@ -55,12 +55,14 @@ import { extractReasoningMiddleware, } from 'ai'; -const wrappedLanguageModel = wrapLanguageModel({ +const model = wrapLanguageModel({ model: yourModel, middleware: extractReasoningMiddleware({ tagName: 'think' }), }); ``` +You can then use that enhanced model in functions like `generateText` and `streamText`. + ## Implementing Language Model Middleware