From f7f09e20839d8a92da79918fc09f2706fabed6c6 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Wed, 20 Mar 2024 11:48:45 -0400 Subject: [PATCH 1/4] Make `generativeModel` an instance method of `VertexAI` --- .../ViewModels/ConversationViewModel.swift | 5 +- .../ViewModels/PhotoReasoningViewModel.swift | 5 +- .../ViewModels/SummarizeViewModel.swift | 5 +- FirebaseVertexAI/Sources/VertexAI.swift | 62 ++++++++----------- .../Sources/VertexAIComponent.swift | 10 +-- 5 files changed, 44 insertions(+), 43 deletions(-) diff --git a/FirebaseVertexAI/Sample/ChatSample/ViewModels/ConversationViewModel.swift b/FirebaseVertexAI/Sample/ChatSample/ViewModels/ConversationViewModel.swift index 465789be387..05cbe11250f 100644 --- a/FirebaseVertexAI/Sample/ChatSample/ViewModels/ConversationViewModel.swift +++ b/FirebaseVertexAI/Sample/ChatSample/ViewModels/ConversationViewModel.swift @@ -36,7 +36,10 @@ class ConversationViewModel: ObservableObject { private var chatTask: Task? init() { - model = VertexAI.generativeModel(modelName: "gemini-1.0-pro", location: "us-central1") + model = VertexAI.vertexAI().generativeModel( + modelName: "gemini-1.0-pro", + location: "us-central1" + ) chat = model.startChat() } diff --git a/FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift b/FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift index dc41e00444a..2f2ed88d4a1 100644 --- a/FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift +++ b/FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift @@ -44,7 +44,10 @@ class PhotoReasoningViewModel: ObservableObject { private var model: GenerativeModel? init() { - model = VertexAI.generativeModel(modelName: "gemini-1.0-pro-vision", location: "us-central1") + model = VertexAI.vertexAI().generativeModel( + modelName: "gemini-1.0-pro-vision", + location: "us-central1" + ) } func reason() async { diff --git a/FirebaseVertexAI/Sample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift b/FirebaseVertexAI/Sample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift index fb5b349ac82..0e3073d6da2 100644 --- a/FirebaseVertexAI/Sample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift +++ b/FirebaseVertexAI/Sample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift @@ -32,7 +32,10 @@ class SummarizeViewModel: ObservableObject { private var model: GenerativeModel? init() { - model = VertexAI.generativeModel(modelName: "gemini-1.0-pro", location: "us-central1") + model = VertexAI.vertexAI().generativeModel( + modelName: "gemini-1.0-pro", + location: "us-central1" + ) } func summarize(inputText: String) async { diff --git a/FirebaseVertexAI/Sources/VertexAI.swift b/FirebaseVertexAI/Sources/VertexAI.swift index 0f79e1be99c..40b0ffc34d7 100644 --- a/FirebaseVertexAI/Sources/VertexAI.swift +++ b/FirebaseVertexAI/Sources/VertexAI.swift @@ -24,66 +24,58 @@ import Foundation open class VertexAI: NSObject { // MARK: - Public APIs - /// Returns an instance of `GoogleGenerativeAI.GenerativeModel` that uses the Vertex AI API. - /// - /// This instance is configured with the default `FirebaseApp`. - /// - /// TODO: Add RequestOptions to public API. - public static func generativeModel(modelName: String, location: String) -> GenerativeModel { + public static func vertexAI() -> VertexAI { guard let app = FirebaseApp.app() else { fatalError("No instance of the default Firebase app was found.") } - return generativeModel(app: app, modelName: modelName, location: location) + + return vertexAI(app: app) } - /// Returns an instance of `GoogleGenerativeAI.GenerativeModel` that uses the Vertex AI API. - /// - /// TODO: Add RequestOptions to public API. - public static func generativeModel(app: FirebaseApp, modelName: String, - location: String) -> GenerativeModel { + public static func vertexAI(app: FirebaseApp) -> VertexAI { guard let provider = ComponentType.instance(for: VertexAIProvider.self, in: app.container) else { fatalError("No \(VertexAIProvider.self) instance found for Firebase app: \(app.name)") } - let modelResourceName = modelResourceName(app: app, modelName: modelName, location: location) - let vertexAI = provider.vertexAI(location: location, modelResourceName: modelResourceName) - return vertexAI.model + return provider.vertexAI() } - // MARK: - Private - - /// The `FirebaseApp` associated with this `VertexAI` instance. - private let app: FirebaseApp - - private let appCheck: AppCheckInterop? - - private let location: String - - private let modelResouceName: String + /// Returns an instance of `GoogleGenerativeAI.GenerativeModel` that uses the Vertex AI API. + public func generativeModel(modelName: String, location: String, + generationConfig: GenerationConfig? = nil, + safetySettings: [SafetySetting]? = nil, + requestOptions: RequestOptions = RequestOptions()) + -> GenerativeModel { + let modelResourceName = modelResourceName(modelName: modelName, location: location) - lazy var model: GenerativeModel = { guard let apiKey = app.options.apiKey else { fatalError("The Firebase app named \"\(app.name)\" has no API key in its configuration.") } + return GenerativeModel( - name: modelResouceName, + name: modelResourceName, apiKey: apiKey, - // TODO: Add RequestOptions to public API. - requestOptions: RequestOptions(), + generationConfig: generationConfig, + safetySettings: safetySettings, + requestOptions: requestOptions, appCheck: appCheck ) - }() + } + + // MARK: - Private + + /// The `FirebaseApp` associated with this `VertexAI` instance. + private let app: FirebaseApp + + private let appCheck: AppCheckInterop? - init(app: FirebaseApp, location: String, modelResourceName: String) { + init(app: FirebaseApp) { self.app = app appCheck = ComponentType.instance(for: AppCheckInterop.self, in: app.container) - self.location = location - modelResouceName = modelResourceName } - private static func modelResourceName(app: FirebaseApp, modelName: String, - location: String) -> String { + private func modelResourceName(modelName: String, location: String) -> String { if modelName.contains("/") { return modelName } diff --git a/FirebaseVertexAI/Sources/VertexAIComponent.swift b/FirebaseVertexAI/Sources/VertexAIComponent.swift index a8d6c177c74..1378f812626 100644 --- a/FirebaseVertexAI/Sources/VertexAIComponent.swift +++ b/FirebaseVertexAI/Sources/VertexAIComponent.swift @@ -22,7 +22,7 @@ import Foundation @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) @objc(FIRVertexAIProvider) protocol VertexAIProvider { - @objc func vertexAI(location: String, modelResourceName: String) -> VertexAI + @objc func vertexAI() -> VertexAI } @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) @@ -64,17 +64,17 @@ class VertexAIComponent: NSObject, Library, VertexAIProvider { // MARK: - VertexAIProvider conformance - func vertexAI(location: String, modelResourceName: String) -> VertexAI { + func vertexAI() -> VertexAI { os_unfair_lock_lock(&instancesLock) // Unlock before the function returns. defer { os_unfair_lock_unlock(&instancesLock) } - if let instance = instances[modelResourceName] { + if let instance = instances[app.name] { return instance } - let newInstance = VertexAI(app: app, location: location, modelResourceName: modelResourceName) - instances[modelResourceName] = newInstance + let newInstance = VertexAI(app: app) + instances[app.name] = newInstance return newInstance } } From 8cd56abee8dccc96db1c8f381785de103d0da603 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Wed, 20 Mar 2024 16:50:03 -0400 Subject: [PATCH 2/4] Add docs --- FirebaseVertexAI/Sources/VertexAI.swift | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/FirebaseVertexAI/Sources/VertexAI.swift b/FirebaseVertexAI/Sources/VertexAI.swift index 40b0ffc34d7..25c8eedff77 100644 --- a/FirebaseVertexAI/Sources/VertexAI.swift +++ b/FirebaseVertexAI/Sources/VertexAI.swift @@ -20,10 +20,12 @@ import Foundation @_implementationOnly import FirebaseCoreExtension @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) -@objc(FIRVertexAI) -open class VertexAI: NSObject { +public class VertexAI: NSObject { // MARK: - Public APIs + /// The default `VertexAI` instance. + /// + /// - Returns: An instance of `VertexAI`, configured with the default `FirebaseApp`. public static func vertexAI() -> VertexAI { guard let app = FirebaseApp.app() else { fatalError("No instance of the default Firebase app was found.") @@ -32,6 +34,10 @@ open class VertexAI: NSObject { return vertexAI(app: app) } + /// Creates an instance of `VertexAI` configured with a custom `FirebaseApp`. + /// + /// - Parameter app: The custom `FirebaseApp` used for initialization. + /// - Returns: A `VertexAI` instance, configured with the custom `FirebaseApp`. public static func vertexAI(app: FirebaseApp) -> VertexAI { guard let provider = ComponentType.instance(for: VertexAIProvider.self, in: app.container) else { @@ -41,7 +47,18 @@ open class VertexAI: NSObject { return provider.vertexAI() } - /// Returns an instance of `GoogleGenerativeAI.GenerativeModel` that uses the Vertex AI API. + /// Initializes a generative model with the given parameters. + /// + /// - Parameters: + /// - modelName: The name of the model to use, e.g., `"gemini-1.0-pro"`; see + /// [Gemini models](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models) + /// for a list of supported model names. + /// - location: The location identifier, e.g., `us-central1`; see + /// [Vertex AI regions](https://cloud.google.com/vertex-ai/docs/general/locations#vertex-ai-regions) + /// for a list of supported locations. + /// - generationConfig: The content generation parameters your model should use. + /// - safetySettings: A value describing what types of harmful content your model should allow. + /// - requestOptions: Configuration parameters for sending requests to the backend. public func generativeModel(modelName: String, location: String, generationConfig: GenerationConfig? = nil, safetySettings: [SafetySetting]? = nil, From 4f94a44de03300326fd0519a91ab6bc04f756c5f Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Wed, 20 Mar 2024 16:58:45 -0400 Subject: [PATCH 3/4] Fix formatting --- FirebaseVertexAI/Sources/VertexAI.swift | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/FirebaseVertexAI/Sources/VertexAI.swift b/FirebaseVertexAI/Sources/VertexAI.swift index 25c8eedff77..64e76d3b4f5 100644 --- a/FirebaseVertexAI/Sources/VertexAI.swift +++ b/FirebaseVertexAI/Sources/VertexAI.swift @@ -51,10 +51,12 @@ public class VertexAI: NSObject { /// /// - Parameters: /// - modelName: The name of the model to use, e.g., `"gemini-1.0-pro"`; see - /// [Gemini models](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models) + /// [Gemini + /// models](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models) /// for a list of supported model names. /// - location: The location identifier, e.g., `us-central1`; see - /// [Vertex AI regions](https://cloud.google.com/vertex-ai/docs/general/locations#vertex-ai-regions) + /// [Vertex AI + /// regions](https://cloud.google.com/vertex-ai/docs/general/locations#vertex-ai-regions) /// for a list of supported locations. /// - generationConfig: The content generation parameters your model should use. /// - safetySettings: A value describing what types of harmful content your model should allow. From ee914b28f415e3a99e50dff7fa853bf7011d90f3 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Thu, 21 Mar 2024 15:31:49 -0400 Subject: [PATCH 4/4] Fix API tests --- .../Tests/Unit/VertexAIAPITests.swift | 49 +++++++++---------- 1 file changed, 22 insertions(+), 27 deletions(-) diff --git a/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift b/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift index 4676b2e34d9..a1fd27ab4c5 100644 --- a/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift +++ b/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift @@ -33,40 +33,35 @@ final class VertexAIAPITests: XCTestCase { stopSequences: ["..."]) let filters = [SafetySetting(harmCategory: .dangerousContent, threshold: .blockOnlyHigh)] + // Instantiate Vertex AI SDK - Default App + let vertexAI = VertexAI.vertexAI() + + // Instantiate Vertex AI SDK - Custom App + let _ = VertexAI.vertexAI(app: app!) + // Permutations without optional arguments. - // TODO: Change `genAI` to `_` when safetySettings and generationConfig are added to public API. - let genAI = VertexAI.generativeModel(modelName: "gemini-1.0-pro", location: "us-central1") - let _ = VertexAI.generativeModel( - app: app!, + let _ = vertexAI.generativeModel(modelName: "gemini-1.0-pro", location: "us-central1") + + let _ = vertexAI.generativeModel( modelName: "gemini-1.0-pro", - location: "us-central1" + location: "us-central1", + safetySettings: filters ) - // TODO: Add safetySettings to public API. - // TODO: Add permutation with `app` specified. - // let _ = VertexAI.generativeModel( - // modelName: "gemini-1.0-pro", - // location: "us-central1", - // safetySettings: filters - // ) - // TODO: Add generationConfig to public API. - // TODO: Add permutation with `app` specified. - // let _ = VertexAI.generativeModel( - // modelName: "gemini-1.0-pro", - // location: "us-central1", - // generationConfig: config - // ) + let _ = vertexAI.generativeModel( + modelName: "gemini-1.0-pro", + location: "us-central1", + generationConfig: config + ) // All arguments passed. - // TODO: Add safetySettings and generationConfig to public API. - // TODO: Add permutation with `app` specified. - // let genAI = VertexAI.generativeModel( - // modelName: "gemini-1.0-pro", - // location: "us-central1", - // generationConfig: config, // Optional - // safetySettings: filters // Optional - // ) + let genAI = vertexAI.generativeModel( + modelName: "gemini-1.0-pro", + location: "us-central1", + generationConfig: config, // Optional + safetySettings: filters // Optional + ) // Full Typed Usage let pngData = Data() // ....