From 04a9c7dbfa29038640421fd6a01cffb79db18d30 Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Mon, 19 Aug 2024 18:07:51 -0700 Subject: [PATCH] groq[minor]: Fix streaming metadata back to client (#6573) * groq[minor]: Fix streaming metadata back to client * chore: lint files * chore: lint files * implemented usage metadata for invoke too --- libs/langchain-groq/src/chat_models.ts | 52 +++++++++++++++++-- .../tests/chat_models.standard.int.test.ts | 16 ------ 2 files changed, 48 insertions(+), 20 deletions(-) diff --git a/libs/langchain-groq/src/chat_models.ts b/libs/langchain-groq/src/chat_models.ts index 823dd41c98f2..79d20b05a546 100644 --- a/libs/langchain-groq/src/chat_models.ts +++ b/libs/langchain-groq/src/chat_models.ts @@ -23,6 +23,7 @@ import { OpenAIToolCall, isAIMessage, BaseMessageChunk, + UsageMetadata, } from "@langchain/core/messages"; import { ChatGeneration, @@ -179,7 +180,8 @@ function convertMessagesToGroqParams( } function groqResponseToChatMessage( - message: ChatCompletionsAPI.ChatCompletionMessage + message: ChatCompletionsAPI.ChatCompletionMessage, + usageMetadata?: UsageMetadata ): BaseMessage { const rawToolCalls: OpenAIToolCall[] | undefined = message.tool_calls as | OpenAIToolCall[] @@ -201,6 +203,7 @@ function groqResponseToChatMessage( additional_kwargs: { tool_calls: rawToolCalls }, tool_calls: toolCalls, invalid_tool_calls: invalidToolCalls, + usage_metadata: usageMetadata, }); } default: @@ -226,7 +229,8 @@ function _convertDeltaToolCallToToolCallChunk( function _convertDeltaToMessageChunk( // eslint-disable-next-line @typescript-eslint/no-explicit-any delta: Record, - index: number + index: number, + xGroq?: ChatCompletionsAPI.ChatCompletionChunk.XGroq ): { message: BaseMessageChunk; toolCallData?: { @@ -250,6 +254,18 @@ function _convertDeltaToMessageChunk( } else { additional_kwargs = {}; } + + let usageMetadata: UsageMetadata | undefined; + let groqMessageId: string | undefined; + if (xGroq?.usage) { + usageMetadata = { + input_tokens: xGroq.usage.prompt_tokens, + output_tokens: xGroq.usage.completion_tokens, + total_tokens: xGroq.usage.total_tokens, + }; + groqMessageId = xGroq.id; + } + if (role === "user") { return { message: new HumanMessageChunk({ content }), @@ -270,6 +286,8 @@ function _convertDeltaToMessageChunk( index: tc.index, })) : undefined, + usage_metadata: usageMetadata, + id: groqMessageId, }), toolCallData: toolCallChunks ? toolCallChunks.map((tc) => ({ @@ -771,7 +789,10 @@ export class ChatGroq extends BaseChatModel< index: number; type: "tool_call_chunk"; }[] = []; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + let responseMetadata: Record | undefined; for await (const data of response) { + responseMetadata = data; const choice = data?.choices[0]; if (!choice) { continue; @@ -787,7 +808,8 @@ export class ChatGroq extends BaseChatModel< ...choice.delta, role, } ?? {}, - choice.index + choice.index, + data.x_groq ); if (toolCallData) { @@ -818,6 +840,19 @@ export class ChatGroq extends BaseChatModel< void runManager?.handleLLMNewToken(chunk.text ?? ""); } + if (responseMetadata) { + if ("choices" in responseMetadata) { + delete responseMetadata.choices; + } + yield new ChatGenerationChunk({ + message: new AIMessageChunk({ + content: "", + response_metadata: responseMetadata, + }), + text: "", + }); + } + if (options.signal?.aborted) { throw new Error("AbortError"); } @@ -898,10 +933,19 @@ export class ChatGroq extends BaseChatModel< if ("choices" in data && data.choices) { for (const part of (data as ChatCompletion).choices) { const text = part.message?.content ?? ""; + let usageMetadata: UsageMetadata | undefined; + if (tokenUsage.totalTokens !== undefined) { + usageMetadata = { + input_tokens: tokenUsage.promptTokens ?? 0, + output_tokens: tokenUsage.completionTokens ?? 0, + total_tokens: tokenUsage.totalTokens, + }; + } const generation: ChatGeneration = { text, message: groqResponseToChatMessage( - part.message ?? { role: "assistant" } + part.message ?? { role: "assistant" }, + usageMetadata ), }; generation.generationInfo = { diff --git a/libs/langchain-groq/src/tests/chat_models.standard.int.test.ts b/libs/langchain-groq/src/tests/chat_models.standard.int.test.ts index 1eb1384b26a7..41dc90a0fdc0 100644 --- a/libs/langchain-groq/src/tests/chat_models.standard.int.test.ts +++ b/libs/langchain-groq/src/tests/chat_models.standard.int.test.ts @@ -25,22 +25,6 @@ class ChatGroqStandardIntegrationTests extends ChatModelIntegrationTests< }); } - async testUsageMetadataStreaming() { - this.skipTestMessage( - "testUsageMetadataStreaming", - "ChatGroq", - "Streaming tokens is not currently supported." - ); - } - - async testUsageMetadata() { - this.skipTestMessage( - "testUsageMetadata", - "ChatGroq", - "Usage metadata tokens is not currently supported." - ); - } - async testToolMessageHistoriesListContent() { this.skipTestMessage( "testToolMessageHistoriesListContent",