Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make generativeModel an instance method of VertexAI #12599

Merged
merged 4 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ class ConversationViewModel: ObservableObject {
private var chatTask: Task<Void, Never>?

init() {
model = VertexAI.generativeModel(modelName: "gemini-1.0-pro", location: "us-central1")
model = VertexAI.vertexAI().generativeModel(
modelName: "gemini-1.0-pro",
location: "us-central1"
)
chat = model.startChat()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ class PhotoReasoningViewModel: ObservableObject {
private var model: GenerativeModel?

init() {
model = VertexAI.generativeModel(modelName: "gemini-1.0-pro-vision", location: "us-central1")
model = VertexAI.vertexAI().generativeModel(
modelName: "gemini-1.0-pro-vision",
location: "us-central1"
)
}

func reason() async {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ class SummarizeViewModel: ObservableObject {
private var model: GenerativeModel?

init() {
model = VertexAI.generativeModel(modelName: "gemini-1.0-pro", location: "us-central1")
model = VertexAI.vertexAI().generativeModel(
modelName: "gemini-1.0-pro",
location: "us-central1"
)
}

func summarize(inputText: String) async {
Expand Down
81 changes: 46 additions & 35 deletions FirebaseVertexAI/Sources/VertexAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,70 +20,81 @@ import Foundation
@_implementationOnly import FirebaseCoreExtension

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
@objc(FIRVertexAI)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it need to be objc to be initialized by the Component system?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so. I was seeing Cannot convert value of type '(any VertexAIProvider).Type' to expected argument type 'Protocol' when I dropped the @objcs.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But it is dropped here?

open class VertexAI: NSObject {
public class VertexAI: NSObject {
// MARK: - Public APIs

/// Returns an instance of `GoogleGenerativeAI.GenerativeModel` that uses the Vertex AI API.
/// The default `VertexAI` instance.
///
/// This instance is configured with the default `FirebaseApp`.
///
/// TODO: Add RequestOptions to public API.
public static func generativeModel(modelName: String, location: String) -> GenerativeModel {
/// - Returns: An instance of `VertexAI`, configured with the default `FirebaseApp`.
public static func vertexAI() -> VertexAI {
guard let app = FirebaseApp.app() else {
fatalError("No instance of the default Firebase app was found.")
}
return generativeModel(app: app, modelName: modelName, location: location)

return vertexAI(app: app)
}

/// Returns an instance of `GoogleGenerativeAI.GenerativeModel` that uses the Vertex AI API.
/// Creates an instance of `VertexAI` configured with a custom `FirebaseApp`.
///
/// TODO: Add RequestOptions to public API.
public static func generativeModel(app: FirebaseApp, modelName: String,
location: String) -> GenerativeModel {
/// - Parameter app: The custom `FirebaseApp` used for initialization.
/// - Returns: A `VertexAI` instance, configured with the custom `FirebaseApp`.
public static func vertexAI(app: FirebaseApp) -> VertexAI {
guard let provider = ComponentType<VertexAIProvider>.instance(for: VertexAIProvider.self,
in: app.container) else {
fatalError("No \(VertexAIProvider.self) instance found for Firebase app: \(app.name)")
}
let modelResourceName = modelResourceName(app: app, modelName: modelName, location: location)
let vertexAI = provider.vertexAI(location: location, modelResourceName: modelResourceName)

return vertexAI.model
return provider.vertexAI()
}

// MARK: - Private

/// The `FirebaseApp` associated with this `VertexAI` instance.
private let app: FirebaseApp

private let appCheck: AppCheckInterop?

private let location: String

private let modelResouceName: String
/// Initializes a generative model with the given parameters.
///
/// - Parameters:
/// - modelName: The name of the model to use, e.g., `"gemini-1.0-pro"`; see
/// [Gemini
/// models](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models)
/// for a list of supported model names.
/// - location: The location identifier, e.g., `us-central1`; see
/// [Vertex AI
/// regions](https://cloud.google.com/vertex-ai/docs/general/locations#vertex-ai-regions)
/// for a list of supported locations.
/// - generationConfig: The content generation parameters your model should use.
/// - safetySettings: A value describing what types of harmful content your model should allow.
/// - requestOptions: Configuration parameters for sending requests to the backend.
public func generativeModel(modelName: String, location: String,
generationConfig: GenerationConfig? = nil,
safetySettings: [SafetySetting]? = nil,
requestOptions: RequestOptions = RequestOptions())
-> GenerativeModel {
let modelResourceName = modelResourceName(modelName: modelName, location: location)

lazy var model: GenerativeModel = {
guard let apiKey = app.options.apiKey else {
fatalError("The Firebase app named \"\(app.name)\" has no API key in its configuration.")
}

return GenerativeModel(
name: modelResouceName,
name: modelResourceName,
apiKey: apiKey,
// TODO: Add RequestOptions to public API.
requestOptions: RequestOptions(),
generationConfig: generationConfig,
safetySettings: safetySettings,
requestOptions: requestOptions,
appCheck: appCheck
)
}()
}

// MARK: - Private

/// The `FirebaseApp` associated with this `VertexAI` instance.
private let app: FirebaseApp

private let appCheck: AppCheckInterop?

init(app: FirebaseApp, location: String, modelResourceName: String) {
init(app: FirebaseApp) {
self.app = app
appCheck = ComponentType<AppCheckInterop>.instance(for: AppCheckInterop.self, in: app.container)
self.location = location
modelResouceName = modelResourceName
}

private static func modelResourceName(app: FirebaseApp, modelName: String,
location: String) -> String {
private func modelResourceName(modelName: String, location: String) -> String {
if modelName.contains("/") {
return modelName
}
Expand Down
10 changes: 5 additions & 5 deletions FirebaseVertexAI/Sources/VertexAIComponent.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import Foundation
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
@objc(FIRVertexAIProvider)
protocol VertexAIProvider {
@objc func vertexAI(location: String, modelResourceName: String) -> VertexAI
@objc func vertexAI() -> VertexAI
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
Expand Down Expand Up @@ -64,17 +64,17 @@ class VertexAIComponent: NSObject, Library, VertexAIProvider {

// MARK: - VertexAIProvider conformance

func vertexAI(location: String, modelResourceName: String) -> VertexAI {
func vertexAI() -> VertexAI {
os_unfair_lock_lock(&instancesLock)

// Unlock before the function returns.
defer { os_unfair_lock_unlock(&instancesLock) }

if let instance = instances[modelResourceName] {
if let instance = instances[app.name] {
return instance
}
let newInstance = VertexAI(app: app, location: location, modelResourceName: modelResourceName)
instances[modelResourceName] = newInstance
let newInstance = VertexAI(app: app)
instances[app.name] = newInstance
return newInstance
}
}
49 changes: 22 additions & 27 deletions FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -33,40 +33,35 @@ final class VertexAIAPITests: XCTestCase {
stopSequences: ["..."])
let filters = [SafetySetting(harmCategory: .dangerousContent, threshold: .blockOnlyHigh)]

// Instantiate Vertex AI SDK - Default App
let vertexAI = VertexAI.vertexAI()

// Instantiate Vertex AI SDK - Custom App
let _ = VertexAI.vertexAI(app: app!)

// Permutations without optional arguments.

// TODO: Change `genAI` to `_` when safetySettings and generationConfig are added to public API.
let genAI = VertexAI.generativeModel(modelName: "gemini-1.0-pro", location: "us-central1")
let _ = VertexAI.generativeModel(
app: app!,
let _ = vertexAI.generativeModel(modelName: "gemini-1.0-pro", location: "us-central1")

let _ = vertexAI.generativeModel(
modelName: "gemini-1.0-pro",
location: "us-central1"
location: "us-central1",
safetySettings: filters
)

// TODO: Add safetySettings to public API.
// TODO: Add permutation with `app` specified.
// let _ = VertexAI.generativeModel(
// modelName: "gemini-1.0-pro",
// location: "us-central1",
// safetySettings: filters
// )
// TODO: Add generationConfig to public API.
// TODO: Add permutation with `app` specified.
// let _ = VertexAI.generativeModel(
// modelName: "gemini-1.0-pro",
// location: "us-central1",
// generationConfig: config
// )
let _ = vertexAI.generativeModel(
modelName: "gemini-1.0-pro",
location: "us-central1",
generationConfig: config
)

// All arguments passed.
// TODO: Add safetySettings and generationConfig to public API.
// TODO: Add permutation with `app` specified.
// let genAI = VertexAI.generativeModel(
// modelName: "gemini-1.0-pro",
// location: "us-central1",
// generationConfig: config, // Optional
// safetySettings: filters // Optional
// )
let genAI = vertexAI.generativeModel(
modelName: "gemini-1.0-pro",
location: "us-central1",
generationConfig: config, // Optional
safetySettings: filters // Optional
)

// Full Typed Usage
let pngData = Data() // ....
Expand Down
Loading