diff --git a/src/Amazon.Bedrock/src/Chat/AnthropicClaude3ChatModel.cs b/src/Amazon.Bedrock/src/Chat/AnthropicClaude3ChatModel.cs index 05e6396..de7664a 100644 --- a/src/Amazon.Bedrock/src/Chat/AnthropicClaude3ChatModel.cs +++ b/src/Amazon.Bedrock/src/Chat/AnthropicClaude3ChatModel.cs @@ -39,6 +39,8 @@ public override async IAsyncEnumerable GenerateAsync( modelSettings: Settings, providerSettings: provider.ChatSettings); + Usage? usage = null; + var bodyJson = CreateBodyJson(prompt, usedSettings, request.Image); if (usedSettings.UseStreaming == true) @@ -51,7 +53,15 @@ public override async IAsyncEnumerable GenerateAsync( var streamEvent = (PayloadPart)payloadPart; var chunk = await JsonSerializer.DeserializeAsync(streamEvent.Bytes, cancellationToken: cancellationToken) .ConfigureAwait(false); + + usage ??= GetUsage(chunk?["message"]?["usage"]); var type = chunk?["type"]!.GetValue().ToUpperInvariant(); + + if (type == "MESSAGE_DELTA") + { + usage += GetUsage(chunk?["usage"]); + } + if (type == "CONTENT_BLOCK_DELTA") { var delta = chunk?["delta"]?["text"]!.GetValue(); @@ -62,10 +72,6 @@ public override async IAsyncEnumerable GenerateAsync( }); stringBuilder.Append(delta); } - if (type == "CONTENT_BLOCK_STOP") - { - break; - } } OnDeltaReceived(new ChatResponseDelta @@ -82,24 +88,27 @@ public override async IAsyncEnumerable GenerateAsync( else { var response = await provider.Api.InvokeModelAsync(Id, bodyJson, cancellationToken).ConfigureAwait(false); + usage = GetUsage(response?["usage"]); var generatedText = response?["content"]?[0]?["text"]?.GetValue() ?? ""; messages.Add(generatedText.AsAiMessage()); } - var usage = Usage.Empty with + usage ??= Usage.Empty; + usage = usage.Value with { Time = watch.Elapsed, + Messages = messages.Count, }; - AddUsage(usage); - provider.AddUsage(usage); + AddUsage(usage.Value); + provider.AddUsage(usage.Value); var chatResponse = new ChatResponse { Messages = messages, UsedSettings = usedSettings, - Usage = usage, + Usage = usage.Value, }; OnResponseReceived(chatResponse); @@ -162,4 +171,24 @@ private static JsonObject CreateBodyJson( return bodyJson; } + + /// + /// Extracts usage information from the provided JSON node. + /// + /// The JSON node containing usage information. + /// A object with the extracted usage data. + private Usage GetUsage(JsonNode? usageNode) + { + var inputTokens = usageNode?["input_tokens"]?.GetValue() ?? 0; + var outputTokens = usageNode?["output_tokens"]?.GetValue() ?? 0; + var priceInUsd = 0.0; + + return Usage.Empty with + { + InputTokens = inputTokens, + OutputTokens = outputTokens, + Messages = 0, + PriceInUsd = priceInUsd, + }; + } } \ No newline at end of file