Skip to content

Commit

Permalink
Manage location on VertexAI instead of model
Browse files Browse the repository at this point in the history
  • Loading branch information
paulb777 committed Mar 26, 2024
1 parent a315fdf commit c7a8593
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,7 @@ class ConversationViewModel: ObservableObject {
private var chatTask: Task<Void, Never>?

init() {
model = VertexAI.vertexAI().generativeModel(
modelName: "gemini-1.0-pro",
location: "us-central1"
)
model = VertexAI.vertexAI(location: "us-central1").generativeModel(modelName: "gemini-1.0-pro")
chat = model.startChat()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,7 @@ class PhotoReasoningViewModel: ObservableObject {
private var model: GenerativeModel?

init() {
model = VertexAI.vertexAI().generativeModel(
modelName: "gemini-1.0-pro-vision",
location: "us-central1"
)
model = VertexAI.vertexAI(location: "us-central1").generativeModel(modelName: "gemini-1.0-pro")
}

func reason() async {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,7 @@ class SummarizeViewModel: ObservableObject {
private var model: GenerativeModel?

init() {
model = VertexAI.vertexAI().generativeModel(
modelName: "gemini-1.0-pro",
location: "us-central1"
)
model = VertexAI.vertexAI(location: "us-central1").generativeModel(modelName: "gemini-1.0-pro")
}

func summarize(inputText: String) async {
Expand Down
30 changes: 19 additions & 11 deletions FirebaseVertexAI/Sources/VertexAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,35 @@ public class VertexAI: NSObject {

/// The default `VertexAI` instance.
///
/// - Parameter location: The location identifier, e.g., `us-central1`; see
/// [Vertex AI
/// regions](https://cloud.google.com/vertex-ai/docs/general/locations#vertex-ai-regions)
/// for a list of supported locations.
/// - Returns: An instance of `VertexAI`, configured with the default `FirebaseApp`.
public static func vertexAI() -> VertexAI {
public static func vertexAI(location: String) -> VertexAI {
guard let app = FirebaseApp.app() else {
fatalError("No instance of the default Firebase app was found.")
}

return vertexAI(app: app)
return vertexAI(app: app, location: location)
}

/// Creates an instance of `VertexAI` configured with a custom `FirebaseApp`.
///
/// - Parameter app: The custom `FirebaseApp` used for initialization.
/// - Parameters:
/// - app: The custom `FirebaseApp` used for initialization.
/// - location: The location identifier, e.g., `us-central1`; see
/// [Vertex AI
/// regions](https://cloud.google.com/vertex-ai/docs/general/locations#vertex-ai-regions)
/// for a list of supported locations.
/// - Returns: A `VertexAI` instance, configured with the custom `FirebaseApp`.
public static func vertexAI(app: FirebaseApp) -> VertexAI {
public static func vertexAI(app: FirebaseApp, location: String) -> VertexAI {
guard let provider = ComponentType<VertexAIProvider>.instance(for: VertexAIProvider.self,
in: app.container) else {
fatalError("No \(VertexAIProvider.self) instance found for Firebase app: \(app.name)")
}

return provider.vertexAI()
return provider.vertexAI(location)
}

/// Initializes a generative model with the given parameters.
Expand All @@ -54,14 +63,10 @@ public class VertexAI: NSObject {
/// [Gemini
/// models](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models)
/// for a list of supported model names.
/// - location: The location identifier, e.g., `us-central1`; see
/// [Vertex AI
/// regions](https://cloud.google.com/vertex-ai/docs/general/locations#vertex-ai-regions)
/// for a list of supported locations.
/// - generationConfig: The content generation parameters your model should use.
/// - safetySettings: A value describing what types of harmful content your model should allow.
/// - requestOptions: Configuration parameters for sending requests to the backend.
public func generativeModel(modelName: String, location: String,
public func generativeModel(modelName: String,
generationConfig: GenerationConfig? = nil,
safetySettings: [SafetySetting]? = nil,
requestOptions: RequestOptions = RequestOptions())
Expand Down Expand Up @@ -89,8 +94,11 @@ public class VertexAI: NSObject {

private let appCheck: AppCheckInterop?

init(app: FirebaseApp) {
private let location: String

init(app: FirebaseApp, location: String) {
self.app = app
self.location = location
appCheck = ComponentType<AppCheckInterop>.instance(for: AppCheckInterop.self, in: app.container)
}

Expand Down
6 changes: 3 additions & 3 deletions FirebaseVertexAI/Sources/VertexAIComponent.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import Foundation
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
@objc(FIRVertexAIProvider)
protocol VertexAIProvider {
@objc func vertexAI() -> VertexAI
@objc func vertexAI(_ location: String) -> VertexAI
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
Expand Down Expand Up @@ -64,7 +64,7 @@ class VertexAIComponent: NSObject, Library, VertexAIProvider {

// MARK: - VertexAIProvider conformance

func vertexAI() -> VertexAI {
func vertexAI(_ location: String) -> VertexAI {
os_unfair_lock_lock(&instancesLock)

// Unlock before the function returns.
Expand All @@ -73,7 +73,7 @@ class VertexAIComponent: NSObject, Library, VertexAIProvider {
if let instance = instances[app.name] {
return instance
}
let newInstance = VertexAI(app: app)
let newInstance = VertexAI(app: app, location: location)
instances[app.name] = newInstance
return newInstance
}
Expand Down
9 changes: 3 additions & 6 deletions FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,31 +34,28 @@ final class VertexAIAPITests: XCTestCase {
let filters = [SafetySetting(harmCategory: .dangerousContent, threshold: .blockOnlyHigh)]

// Instantiate Vertex AI SDK - Default App
let vertexAI = VertexAI.vertexAI()
let vertexAI = VertexAI.vertexAI(location: "my-location")

// Instantiate Vertex AI SDK - Custom App
let _ = VertexAI.vertexAI(app: app!)
let _ = VertexAI.vertexAI(app: app!, location: "my-location")

// Permutations without optional arguments.

let _ = vertexAI.generativeModel(modelName: "gemini-1.0-pro", location: "us-central1")
let _ = vertexAI.generativeModel(modelName: "gemini-1.0-pro")

let _ = vertexAI.generativeModel(
modelName: "gemini-1.0-pro",
location: "us-central1",
safetySettings: filters
)

let _ = vertexAI.generativeModel(
modelName: "gemini-1.0-pro",
location: "us-central1",
generationConfig: config
)

// All arguments passed.
let genAI = vertexAI.generativeModel(
modelName: "gemini-1.0-pro",
location: "us-central1",
generationConfig: config, // Optional
safetySettings: filters // Optional
)
Expand Down

0 comments on commit c7a8593

Please sign in to comment.