From 44fe5d08d9f7fa33e4fe7816273fd271d98ebfe3 Mon Sep 17 00:00:00 2001 From: Rodrigo Lazo Date: Tue, 11 Feb 2025 15:41:25 -0500 Subject: [PATCH] [VertexAI] Add support for token-based usage metrics (#14406) --- FirebaseVertexAI/CHANGELOG.md | 1 + .../Sources/CountTokensRequest.swift | 3 + .../Sources/GenerateContentResponse.swift | 14 +++++ .../Sources/ModalityTokenCount.swift | 61 +++++++++++++++++++ FirebaseVertexAI/Sources/VertexLog.swift | 1 + .../Tests/Unit/GenerativeModelTests.swift | 43 +++++++++++++ 6 files changed, 123 insertions(+) create mode 100644 FirebaseVertexAI/Sources/ModalityTokenCount.swift diff --git a/FirebaseVertexAI/CHANGELOG.md b/FirebaseVertexAI/CHANGELOG.md index d8a5b49802e..ab9fb24015c 100644 --- a/FirebaseVertexAI/CHANGELOG.md +++ b/FirebaseVertexAI/CHANGELOG.md @@ -5,6 +5,7 @@ Note: This feature is in Public Preview, which means that the it is not subject to any SLA or deprecation policy and could change in backwards-incompatible ways. +- [feature] Added support for modality-based token count. (#14406) # 11.6.0 - [changed] The token counts from `GenerativeModel.countTokens(...)` now include diff --git a/FirebaseVertexAI/Sources/CountTokensRequest.swift b/FirebaseVertexAI/Sources/CountTokensRequest.swift index 6c36d96b4c0..f8d3fb2241e 100644 --- a/FirebaseVertexAI/Sources/CountTokensRequest.swift +++ b/FirebaseVertexAI/Sources/CountTokensRequest.swift @@ -46,6 +46,9 @@ public struct CountTokensResponse { /// > Important: This does not include billable image, video or other non-text input. See /// [Vertex AI pricing](https://cloud.google.com/vertex-ai/generative-ai/pricing) for details. public let totalBillableCharacters: Int? + + /// The breakdown, by modality, of how many tokens are consumed by the prompt. + public let promptTokensDetails: [ModalityTokenCount] } // MARK: - Codable Conformances diff --git a/FirebaseVertexAI/Sources/GenerateContentResponse.swift b/FirebaseVertexAI/Sources/GenerateContentResponse.swift index b7b4f1c536a..e81407d708a 100644 --- a/FirebaseVertexAI/Sources/GenerateContentResponse.swift +++ b/FirebaseVertexAI/Sources/GenerateContentResponse.swift @@ -28,6 +28,12 @@ public struct GenerateContentResponse: Sendable { /// The total number of tokens in both the request and response. public let totalTokenCount: Int + + /// The breakdown, by modality, of how many tokens are consumed by the prompt + public let promptTokensDetails: [ModalityTokenCount] + + /// The breakdown, by modality, of how many tokens are consumed by the candidates + public let candidatesTokensDetails: [ModalityTokenCount] } /// A list of candidate response content, ordered from best to worst. @@ -299,6 +305,8 @@ extension GenerateContentResponse.UsageMetadata: Decodable { case promptTokenCount case candidatesTokenCount case totalTokenCount + case promptTokensDetails + case candidatesTokensDetails } public init(from decoder: any Decoder) throws { @@ -307,6 +315,12 @@ extension GenerateContentResponse.UsageMetadata: Decodable { candidatesTokenCount = try container .decodeIfPresent(Int.self, forKey: .candidatesTokenCount) ?? 0 totalTokenCount = try container.decodeIfPresent(Int.self, forKey: .totalTokenCount) ?? 0 + promptTokensDetails = try container + .decodeIfPresent([ModalityTokenCount].self, forKey: .promptTokensDetails) ?? + [ModalityTokenCount]() + candidatesTokensDetails = try container + .decodeIfPresent([ModalityTokenCount].self, forKey: .candidatesTokensDetails) ?? + [ModalityTokenCount]() } } diff --git a/FirebaseVertexAI/Sources/ModalityTokenCount.swift b/FirebaseVertexAI/Sources/ModalityTokenCount.swift new file mode 100644 index 00000000000..457e31d4109 --- /dev/null +++ b/FirebaseVertexAI/Sources/ModalityTokenCount.swift @@ -0,0 +1,61 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +/// Represents token counting info for a single modality. +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +public struct ModalityTokenCount: Sendable { + /// The modality associated with this token count. + public let modality: ContentModality + + /// The number of tokens counted. + public let tokenCount: Int +} + +/// Content part modality. +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +public struct ContentModality: DecodableProtoEnum, Hashable, Sendable { + enum Kind: String { + case text = "TEXT" + case image = "IMAGE" + case video = "VIDEO" + case audio = "AUDIO" + case document = "DOCUMENT" + } + + /// Plain text. + public static let text = ContentModality(kind: .text) + + /// Image. + public static let image = ContentModality(kind: .image) + + /// Video. + public static let video = ContentModality(kind: .video) + + /// Audio. + public static let audio = ContentModality(kind: .audio) + + /// Document, e.g. PDF. + public static let document = ContentModality(kind: .document) + + /// Returns the raw string representation of the `ContentModality` value. + public let rawValue: String + + static let unrecognizedValueMessageCode = + VertexLog.MessageCode.generateContentResponseUnrecognizedContentModality +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension ModalityTokenCount: Decodable {} diff --git a/FirebaseVertexAI/Sources/VertexLog.swift b/FirebaseVertexAI/Sources/VertexLog.swift index 822908a4986..792d13358f6 100644 --- a/FirebaseVertexAI/Sources/VertexLog.swift +++ b/FirebaseVertexAI/Sources/VertexLog.swift @@ -57,6 +57,7 @@ enum VertexLog { case decodedInvalidProtoDateMonth = 3009 case decodedInvalidProtoDateDay = 3010 case decodedInvalidCitationPublicationDate = 3011 + case generateContentResponseUnrecognizedContentModality = 3012 // SDK State Errors case generateContentResponseNoCandidates = 4000 diff --git a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift index e7571c844f3..3ed40ce2530 100644 --- a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift +++ b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift @@ -128,6 +128,30 @@ final class GenerativeModelTests: XCTestCase { XCTAssertEqual(response.functionCalls, []) } + func testGenerateContent_success_basicReplyFullUsageMetadata() async throws { + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "unary-success-basic-response-long-usage-metadata", + withExtension: "json" + ) + + let response = try await model.generateContent(testPrompt) + + XCTAssertEqual(response.candidates.count, 1) + let candidate = try XCTUnwrap(response.candidates.first) + let finishReason = try XCTUnwrap(candidate.finishReason) + XCTAssertEqual(finishReason, .stop) + let usageMetadata = try XCTUnwrap(response.usageMetadata) + XCTAssertEqual(usageMetadata.promptTokensDetails.count, 2) + XCTAssertEqual(usageMetadata.promptTokensDetails[0].modality, .image) + XCTAssertEqual(usageMetadata.promptTokensDetails[0].tokenCount, 1806) + XCTAssertEqual(usageMetadata.promptTokensDetails[1].modality, .text) + XCTAssertEqual(usageMetadata.promptTokensDetails[1].tokenCount, 76) + XCTAssertEqual(usageMetadata.candidatesTokensDetails.count, 1) + XCTAssertEqual(usageMetadata.candidatesTokensDetails[0].modality, .text) + XCTAssertEqual(usageMetadata.candidatesTokensDetails[0].tokenCount, 76) + } + func testGenerateContent_success_citations() async throws { MockURLProtocol .requestHandler = try httpRequestHandler( @@ -488,6 +512,8 @@ final class GenerativeModelTests: XCTestCase { XCTAssertEqual(usageMetadata.promptTokenCount, 6) XCTAssertEqual(usageMetadata.candidatesTokenCount, 7) XCTAssertEqual(usageMetadata.totalTokenCount, 13) + XCTAssertEqual(usageMetadata.promptTokensDetails.isEmpty, true) + XCTAssertEqual(usageMetadata.candidatesTokensDetails.isEmpty, true) } func testGenerateContent_failure_invalidAPIKey() async throws { @@ -1326,6 +1352,23 @@ final class GenerativeModelTests: XCTestCase { XCTAssertEqual(response.totalBillableCharacters, 16) } + func testCountTokens_succeeds_detailed() async throws { + MockURLProtocol.requestHandler = try httpRequestHandler( + forResource: "unary-success-detailed-token-response", + withExtension: "json" + ) + + let response = try await model.countTokens("Why is the sky blue?") + + XCTAssertEqual(response.totalTokens, 1837) + XCTAssertEqual(response.totalBillableCharacters, 117) + XCTAssertEqual(response.promptTokensDetails.count, 2) + XCTAssertEqual(response.promptTokensDetails[0].modality, .image) + XCTAssertEqual(response.promptTokensDetails[0].tokenCount, 1806) + XCTAssertEqual(response.promptTokensDetails[1].modality, .text) + XCTAssertEqual(response.promptTokensDetails[1].tokenCount, 31) + } + func testCountTokens_succeeds_allOptions() async throws { MockURLProtocol.requestHandler = try httpRequestHandler( forResource: "unary-success-total-tokens",