-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add prototype for Firebase Vertex AI
- Loading branch information
1 parent
939a521
commit 42df2e3
Showing
4 changed files
with
256 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
// Copyright 2024 Google LLC | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
import Foundation | ||
|
||
import FirebaseAppCheckInterop | ||
import FirebaseCore | ||
import GoogleGenerativeAI | ||
|
||
// Avoids exposing internal FirebaseCore APIs to Swift users. | ||
@_implementationOnly import FirebaseCoreExtension | ||
|
||
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) | ||
@objc(FIRVertexAI) | ||
open class VertexAI: NSObject { | ||
// MARK: - Public APIs | ||
|
||
/// The default `VertexAI` instance. | ||
/// | ||
/// - Returns: An instance of `VertexAI`, configured with the default `FirebaseApp`. | ||
public static func vertexAI(modelName: String, location: String) -> VertexAI { | ||
return vertexAI(app: FirebaseApp.app()!, modelName: modelName, location: location) | ||
} | ||
|
||
public static func vertexAI(app: FirebaseApp, modelName: String, location: String) -> VertexAI { | ||
let provider = ComponentType<VertexAIProvider>.instance(for: VertexAIProvider.self, | ||
in: app.container) | ||
let modelResourceName = modelResourceName(app: app, modelName: modelName, location: location) | ||
return provider.vertexAI(location: location, modelResourceName: modelResourceName) | ||
} | ||
|
||
public func generateContentStream(prompt: String) async | ||
-> AsyncThrowingStream<GenerateContentResponse, Error> { | ||
return model.generateContentStream(prompt) | ||
} | ||
|
||
// MARK: - Private | ||
|
||
/// The `FirebaseApp` associated with this `VertexAI` instance. | ||
private let app: FirebaseApp | ||
|
||
private let appCheck: AppCheckInterop? | ||
|
||
private let location: String | ||
|
||
private let modelResouceName: String | ||
|
||
lazy var model: GenerativeModel = { | ||
let options = RequestOptions(hooks: [ | ||
setVertexAIEndpoint, | ||
addAccessTokenHeader, | ||
addAppCheckHeader, | ||
]) | ||
return GenerativeModel( | ||
name: modelResouceName, | ||
apiKey: app.options.apiKey!, | ||
requestOptions: options | ||
) | ||
}() | ||
|
||
private static let accessTokenEnvKey = "FIRVertexAIAccessToken" | ||
|
||
init(app: FirebaseApp, location: String, modelResourceName: String) { | ||
self.app = app | ||
appCheck = ComponentType<AppCheckInterop>.instance(for: AppCheckInterop.self, in: app.container) | ||
self.location = location | ||
modelResouceName = modelResourceName | ||
} | ||
|
||
private static func modelResourceName(app: FirebaseApp, modelName: String, | ||
location: String) -> String { | ||
if modelName.contains("/") { | ||
return modelName | ||
} | ||
guard let projectID = app.options.projectID else { | ||
print("The FirebaseApp is missing a project ID.") | ||
return modelName | ||
} | ||
|
||
return "projects/\(projectID)/locations/\(location)/publishers/google/models/\(modelName)" | ||
} | ||
|
||
// MARK: Request Hooks | ||
|
||
/// Replace the Labs endpoint with a Vertex AI endpoint in the provided request. | ||
/// | ||
/// This is temporary workaround until the Google Generative AI SDK supports setting an endpoint. | ||
/// | ||
/// - Parameter request: The `URLRequest` to modify with a Vertex AI hostname. | ||
func setVertexAIEndpoint(request: inout URLRequest) { | ||
guard let requestURL = request.url else { | ||
return | ||
} | ||
guard var urlComponents = URLComponents(url: requestURL, resolvingAgainstBaseURL: false) else { | ||
return | ||
} | ||
urlComponents.host = "\(location)-aiplatform.googleapis.com" | ||
|
||
guard let componentsURL = urlComponents.url else { | ||
return | ||
} | ||
|
||
request.url = componentsURL | ||
} | ||
|
||
/// Add a Google Cloud access token in an Authorization header in the provided request. | ||
/// | ||
/// This is a temporary workaround until Vertex AI can be called with an API key. | ||
/// | ||
/// - Parameter request: The `URLRequest` to modify by adding an access token. | ||
func addAccessTokenHeader(request: inout URLRequest) { | ||
// Remove the API key header, it is not supported by Vertex AI. | ||
if var headers = request.allHTTPHeaderFields { | ||
headers.removeValue(forKey: "x-goog-api-key") | ||
} | ||
|
||
guard let accessToken = ProcessInfo.processInfo.environment[VertexAI.accessTokenEnvKey] else { | ||
print(""" | ||
Vertex AI requires an Access Token for authorization: | ||
1. Get an access token by running `gcloud auth print-access-token` | ||
2. Set it in the \(VertexAI.accessTokenEnvKey) environment variable. | ||
""") | ||
return | ||
} | ||
|
||
request.addValue("Bearer \(accessToken)", forHTTPHeaderField: "Authorization") | ||
} | ||
|
||
/// Adds an App Check token to the provided request, if App Check is included in the app. | ||
/// | ||
/// This demonstrates how an App Check token can be added to requests; it is currently ignored by | ||
/// the backend. | ||
/// | ||
/// - Parameter request: The `URLRequest` to modify by adding an App Check token header. | ||
func addAppCheckHeader(request: inout URLRequest) async { | ||
guard let appCheck = appCheck else { | ||
return | ||
} | ||
|
||
let tokenResult = await appCheck.getToken(forcingRefresh: false) | ||
request.addValue(tokenResult.token, forHTTPHeaderField: "X-Firebase-AppCheck") | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
// Copyright 2024 Google LLC | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
import FirebaseAppCheckInterop | ||
import FirebaseCore | ||
import Foundation | ||
|
||
// Avoids exposing internal FirebaseCore APIs to Swift users. | ||
@_implementationOnly import FirebaseCoreExtension | ||
|
||
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) | ||
@objc(FIRVertexAIProvider) | ||
protocol VertexAIProvider { | ||
@objc func vertexAI(location: String, modelResourceName: String) -> VertexAI | ||
} | ||
|
||
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) | ||
@objc(FIRVertexAIComponent) | ||
class VertexAIComponent: NSObject, Library, VertexAIProvider { | ||
// MARK: - Private Variables | ||
|
||
/// The app associated with all `VertexAI` instances in this container. | ||
/// This is `unowned` instead of `weak` so it can be used without unwrapping in `vertexAI(...)` | ||
private unowned let app: FirebaseApp | ||
|
||
/// A map of active `VertexAI` instances for `app`, keyed by model resource names | ||
/// (e.g., "projects/my-project-id/locations/us-central1/publishers/google/models/gemini-pro"). | ||
private var instances: [String: VertexAI] = [:] | ||
|
||
/// Lock to manage access to the `instances` array to avoid race conditions. | ||
private var instancesLock: os_unfair_lock = .init() | ||
|
||
// MARK: - Initializers | ||
|
||
required init(app: FirebaseApp) { | ||
self.app = app | ||
} | ||
|
||
// MARK: - Library conformance | ||
|
||
static func componentsToRegister() -> [Component] { | ||
let appCheckInterop = Dependency(with: AppCheckInterop.self, isRequired: false) | ||
return [Component(VertexAIProvider.self, | ||
instantiationTiming: .lazy, | ||
dependencies: [ | ||
appCheckInterop, | ||
]) { container, isCacheable in | ||
guard let app = container.app else { return nil } | ||
isCacheable.pointee = true | ||
return self.init(app: app) | ||
}] | ||
} | ||
|
||
// MARK: - VertexAIProvider conformance | ||
|
||
func vertexAI(location: String, modelResourceName: String) -> VertexAI { | ||
os_unfair_lock_lock(&instancesLock) | ||
|
||
// Unlock before the function returns. | ||
defer { os_unfair_lock_unlock(&instancesLock) } | ||
|
||
if let instance = instances[modelResourceName] { | ||
return instance | ||
} | ||
let newInstance = VertexAI(app: app, location: location, modelResourceName: modelResourceName) | ||
instances[modelResourceName] = newInstance | ||
return newInstance | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters