From f69fc1a801aec5853eb58f482f77053a485b7bb4 Mon Sep 17 00:00:00 2001
From: GregoireDaure <118850966+GregoireDaure@users.noreply.github.com>
Date: Wed, 29 Jan 2025 11:03:16 +0100
Subject: [PATCH] feat: Add usage for Bedrock Anthropic Claude models (#157)

* Add usage tracking to AnthropicClaude3ChatModel

Introduce a `usage` variable to track usage information throughout the `AnthropicClaude3ChatModel` class. Add a `GetUsage` method to extract usage data from JSON nodes. Initialize `usage` to `null` and update it with data from JSON responses, ensuring it is non-null before use with a fallback to `Usage.Empty`. Update `usage` incrementally in streaming responses and include it in `ChatResponse`, passing it to `AddUsage` and `provider.AddUsage` methods. Remove `CONTENT_BLOCK_STOP` type handling from streaming response processing.

* Improve null safety in usage variable assignment

Modified the assignment of the `usage` variable to use `GetUsage(response?["usage"])` instead of `GetUsage(response)`. This change ensures that the `usage` variable is assigned from the "usage" field of the `response` object, if it exists, preventing potential null reference exceptions.
---
 .../src/Chat/AnthropicClaude3ChatModel.cs     | 45 +++++++++++++++----
 1 file changed, 37 insertions(+), 8 deletions(-)

diff --git a/src/Amazon.Bedrock/src/Chat/AnthropicClaude3ChatModel.cs b/src/Amazon.Bedrock/src/Chat/AnthropicClaude3ChatModel.cs
index 05e6396..de7664a 100644
--- a/src/Amazon.Bedrock/src/Chat/AnthropicClaude3ChatModel.cs
+++ b/src/Amazon.Bedrock/src/Chat/AnthropicClaude3ChatModel.cs
@@ -39,6 +39,8 @@ public override async IAsyncEnumerable<ChatResponse> GenerateAsync(
             modelSettings: Settings,
             providerSettings: provider.ChatSettings);
 
+        Usage? usage = null;
+
         var bodyJson = CreateBodyJson(prompt, usedSettings, request.Image);
 
         if (usedSettings.UseStreaming == true)
@@ -51,7 +53,15 @@ public override async IAsyncEnumerable<ChatResponse> GenerateAsync(
                 var streamEvent = (PayloadPart)payloadPart;
                 var chunk = await JsonSerializer.DeserializeAsync<JsonObject>(streamEvent.Bytes, cancellationToken: cancellationToken)
                     .ConfigureAwait(false);
+
+                usage ??= GetUsage(chunk?["message"]?["usage"]);
                 var type = chunk?["type"]!.GetValue<string>().ToUpperInvariant();
+
+                if (type == "MESSAGE_DELTA")
+                {
+                    usage += GetUsage(chunk?["usage"]);
+                }
+
                 if (type == "CONTENT_BLOCK_DELTA")
                 {
                     var delta = chunk?["delta"]?["text"]!.GetValue<string>();
@@ -62,10 +72,6 @@ public override async IAsyncEnumerable<ChatResponse> GenerateAsync(
                     });
                     stringBuilder.Append(delta);
                 }
-                if (type == "CONTENT_BLOCK_STOP")
-                {
-                    break;
-                }
             }
 
             OnDeltaReceived(new ChatResponseDelta
@@ -82,24 +88,27 @@ public override async IAsyncEnumerable<ChatResponse> GenerateAsync(
         else
         {
             var response = await provider.Api.InvokeModelAsync(Id, bodyJson, cancellationToken).ConfigureAwait(false);
+            usage = GetUsage(response?["usage"]);
 
             var generatedText = response?["content"]?[0]?["text"]?.GetValue<string>() ?? "";
 
             messages.Add(generatedText.AsAiMessage());
         }
 
-        var usage = Usage.Empty with
+        usage ??= Usage.Empty;
+        usage = usage.Value with
         {
             Time = watch.Elapsed,
+            Messages = messages.Count,
         };
-        AddUsage(usage);
-        provider.AddUsage(usage);
+        AddUsage(usage.Value);
+        provider.AddUsage(usage.Value);
 
         var chatResponse = new ChatResponse
         {
             Messages = messages,
             UsedSettings = usedSettings,
-            Usage = usage,
+            Usage = usage.Value,
         };
         OnResponseReceived(chatResponse);
 
@@ -162,4 +171,24 @@ private static JsonObject CreateBodyJson(
 
         return bodyJson;
     }
+
+    /// <summary>
+    /// Extracts usage information from the provided JSON node.
+    /// </summary>
+    /// <param name="usageNode">The JSON node containing usage information.</param>
+    /// <returns>A <see cref="Usage"/> object with the extracted usage data.</returns>
+    private Usage GetUsage(JsonNode? usageNode)
+    {
+        var inputTokens = usageNode?["input_tokens"]?.GetValue<int>() ?? 0;
+        var outputTokens = usageNode?["output_tokens"]?.GetValue<int>() ?? 0;
+        var priceInUsd = 0.0;
+
+        return Usage.Empty with
+        {
+            InputTokens = inputTokens,
+            OutputTokens = outputTokens,
+            Messages = 0,
+            PriceInUsd = priceInUsd,
+        };
+    }
 }
\ No newline at end of file