From c7a85938152338c3e2bb7393ea1380e449a59383 Mon Sep 17 00:00:00 2001 From: Paul Beusterien Date: Tue, 26 Mar 2024 11:47:29 -0700 Subject: [PATCH] Manage location on VertexAI instead of model --- .../ViewModels/ConversationViewModel.swift | 5 +--- .../ViewModels/PhotoReasoningViewModel.swift | 5 +--- .../ViewModels/SummarizeViewModel.swift | 5 +--- FirebaseVertexAI/Sources/VertexAI.swift | 30 ++++++++++++------- .../Sources/VertexAIComponent.swift | 6 ++-- .../Tests/Unit/VertexAIAPITests.swift | 9 ++---- 6 files changed, 28 insertions(+), 32 deletions(-) diff --git a/FirebaseVertexAI/Sample/ChatSample/ViewModels/ConversationViewModel.swift b/FirebaseVertexAI/Sample/ChatSample/ViewModels/ConversationViewModel.swift index 05cbe11250f..883cefb359f 100644 --- a/FirebaseVertexAI/Sample/ChatSample/ViewModels/ConversationViewModel.swift +++ b/FirebaseVertexAI/Sample/ChatSample/ViewModels/ConversationViewModel.swift @@ -36,10 +36,7 @@ class ConversationViewModel: ObservableObject { private var chatTask: Task? init() { - model = VertexAI.vertexAI().generativeModel( - modelName: "gemini-1.0-pro", - location: "us-central1" - ) + model = VertexAI.vertexAI(location: "us-central1").generativeModel(modelName: "gemini-1.0-pro") chat = model.startChat() } diff --git a/FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift b/FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift index 2f2ed88d4a1..0fe81277a37 100644 --- a/FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift +++ b/FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift @@ -44,10 +44,7 @@ class PhotoReasoningViewModel: ObservableObject { private var model: GenerativeModel? init() { - model = VertexAI.vertexAI().generativeModel( - modelName: "gemini-1.0-pro-vision", - location: "us-central1" - ) + model = VertexAI.vertexAI(location: "us-central1").generativeModel(modelName: "gemini-1.0-pro") } func reason() async { diff --git a/FirebaseVertexAI/Sample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift b/FirebaseVertexAI/Sample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift index 0e3073d6da2..a90e1cf15b8 100644 --- a/FirebaseVertexAI/Sample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift +++ b/FirebaseVertexAI/Sample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift @@ -32,10 +32,7 @@ class SummarizeViewModel: ObservableObject { private var model: GenerativeModel? init() { - model = VertexAI.vertexAI().generativeModel( - modelName: "gemini-1.0-pro", - location: "us-central1" - ) + model = VertexAI.vertexAI(location: "us-central1").generativeModel(modelName: "gemini-1.0-pro") } func summarize(inputText: String) async { diff --git a/FirebaseVertexAI/Sources/VertexAI.swift b/FirebaseVertexAI/Sources/VertexAI.swift index 64e76d3b4f5..1186616f119 100644 --- a/FirebaseVertexAI/Sources/VertexAI.swift +++ b/FirebaseVertexAI/Sources/VertexAI.swift @@ -25,26 +25,35 @@ public class VertexAI: NSObject { /// The default `VertexAI` instance. /// + /// - Parameter location: The location identifier, e.g., `us-central1`; see + /// [Vertex AI + /// regions](https://cloud.google.com/vertex-ai/docs/general/locations#vertex-ai-regions) + /// for a list of supported locations. /// - Returns: An instance of `VertexAI`, configured with the default `FirebaseApp`. - public static func vertexAI() -> VertexAI { + public static func vertexAI(location: String) -> VertexAI { guard let app = FirebaseApp.app() else { fatalError("No instance of the default Firebase app was found.") } - return vertexAI(app: app) + return vertexAI(app: app, location: location) } /// Creates an instance of `VertexAI` configured with a custom `FirebaseApp`. /// - /// - Parameter app: The custom `FirebaseApp` used for initialization. + /// - Parameters: + /// - app: The custom `FirebaseApp` used for initialization. + /// - location: The location identifier, e.g., `us-central1`; see + /// [Vertex AI + /// regions](https://cloud.google.com/vertex-ai/docs/general/locations#vertex-ai-regions) + /// for a list of supported locations. /// - Returns: A `VertexAI` instance, configured with the custom `FirebaseApp`. - public static func vertexAI(app: FirebaseApp) -> VertexAI { + public static func vertexAI(app: FirebaseApp, location: String) -> VertexAI { guard let provider = ComponentType.instance(for: VertexAIProvider.self, in: app.container) else { fatalError("No \(VertexAIProvider.self) instance found for Firebase app: \(app.name)") } - return provider.vertexAI() + return provider.vertexAI(location) } /// Initializes a generative model with the given parameters. @@ -54,14 +63,10 @@ public class VertexAI: NSObject { /// [Gemini /// models](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models) /// for a list of supported model names. - /// - location: The location identifier, e.g., `us-central1`; see - /// [Vertex AI - /// regions](https://cloud.google.com/vertex-ai/docs/general/locations#vertex-ai-regions) - /// for a list of supported locations. /// - generationConfig: The content generation parameters your model should use. /// - safetySettings: A value describing what types of harmful content your model should allow. /// - requestOptions: Configuration parameters for sending requests to the backend. - public func generativeModel(modelName: String, location: String, + public func generativeModel(modelName: String, generationConfig: GenerationConfig? = nil, safetySettings: [SafetySetting]? = nil, requestOptions: RequestOptions = RequestOptions()) @@ -89,8 +94,11 @@ public class VertexAI: NSObject { private let appCheck: AppCheckInterop? - init(app: FirebaseApp) { + private let location: String + + init(app: FirebaseApp, location: String) { self.app = app + self.location = location appCheck = ComponentType.instance(for: AppCheckInterop.self, in: app.container) } diff --git a/FirebaseVertexAI/Sources/VertexAIComponent.swift b/FirebaseVertexAI/Sources/VertexAIComponent.swift index 1378f812626..5a8def8117f 100644 --- a/FirebaseVertexAI/Sources/VertexAIComponent.swift +++ b/FirebaseVertexAI/Sources/VertexAIComponent.swift @@ -22,7 +22,7 @@ import Foundation @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) @objc(FIRVertexAIProvider) protocol VertexAIProvider { - @objc func vertexAI() -> VertexAI + @objc func vertexAI(_ location: String) -> VertexAI } @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) @@ -64,7 +64,7 @@ class VertexAIComponent: NSObject, Library, VertexAIProvider { // MARK: - VertexAIProvider conformance - func vertexAI() -> VertexAI { + func vertexAI(_ location: String) -> VertexAI { os_unfair_lock_lock(&instancesLock) // Unlock before the function returns. @@ -73,7 +73,7 @@ class VertexAIComponent: NSObject, Library, VertexAIProvider { if let instance = instances[app.name] { return instance } - let newInstance = VertexAI(app: app) + let newInstance = VertexAI(app: app, location: location) instances[app.name] = newInstance return newInstance } diff --git a/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift b/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift index a1fd27ab4c5..80236f405de 100644 --- a/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift +++ b/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift @@ -34,31 +34,28 @@ final class VertexAIAPITests: XCTestCase { let filters = [SafetySetting(harmCategory: .dangerousContent, threshold: .blockOnlyHigh)] // Instantiate Vertex AI SDK - Default App - let vertexAI = VertexAI.vertexAI() + let vertexAI = VertexAI.vertexAI(location: "my-location") // Instantiate Vertex AI SDK - Custom App - let _ = VertexAI.vertexAI(app: app!) + let _ = VertexAI.vertexAI(app: app!, location: "my-location") // Permutations without optional arguments. - let _ = vertexAI.generativeModel(modelName: "gemini-1.0-pro", location: "us-central1") + let _ = vertexAI.generativeModel(modelName: "gemini-1.0-pro") let _ = vertexAI.generativeModel( modelName: "gemini-1.0-pro", - location: "us-central1", safetySettings: filters ) let _ = vertexAI.generativeModel( modelName: "gemini-1.0-pro", - location: "us-central1", generationConfig: config ) // All arguments passed. let genAI = vertexAI.generativeModel( modelName: "gemini-1.0-pro", - location: "us-central1", generationConfig: config, // Optional safetySettings: filters // Optional )