Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Vertex AI] Add ImagenModelConfig for model-level config params #14315

Merged
merged 1 commit into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,11 @@ public struct ImagenGenerationConfig {
public var numberOfImages: Int?
public var negativePrompt: String?
public var aspectRatio: ImagenAspectRatio?
public var imageFormat: ImagenImageFormat?
public var addWatermark: Bool?

public init(numberOfImages: Int? = nil,
negativePrompt: String? = nil,
aspectRatio: ImagenAspectRatio? = nil,
imageFormat: ImagenImageFormat? = nil,
addWatermark: Bool? = nil) {
public init(numberOfImages: Int? = nil, negativePrompt: String? = nil,
aspectRatio: ImagenAspectRatio? = nil) {
self.numberOfImages = numberOfImages
self.negativePrompt = negativePrompt
self.aspectRatio = aspectRatio
self.imageFormat = imageFormat
self.addWatermark = addWatermark
}
}
11 changes: 9 additions & 2 deletions FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ public final class ImagenModel {
/// The backing service responsible for sending and receiving model requests to the backend.
let generativeAIService: GenerativeAIService

let modelConfig: ImagenModelConfig?

let safetySettings: ImagenSafetySettings?

/// Configuration parameters for sending requests to the backend.
Expand All @@ -32,6 +34,7 @@ public final class ImagenModel {
init(name: String,
projectID: String,
apiKey: String,
modelConfig: ImagenModelConfig?,
safetySettings: ImagenSafetySettings?,
requestOptions: RequestOptions,
appCheck: AppCheckInterop?,
Expand All @@ -45,6 +48,7 @@ public final class ImagenModel {
auth: auth,
urlSession: urlSession
)
self.modelConfig = modelConfig
self.safetySettings = safetySettings
self.requestOptions = requestOptions
}
Expand All @@ -57,6 +61,7 @@ public final class ImagenModel {
parameters: ImagenModel.imageGenerationParameters(
storageURI: nil,
generationConfig: generationConfig,
modelConfig: modelConfig,
safetySettings: safetySettings
)
)
Expand All @@ -70,6 +75,7 @@ public final class ImagenModel {
parameters: ImagenModel.imageGenerationParameters(
storageURI: storageURI,
generationConfig: generationConfig,
modelConfig: modelConfig,
safetySettings: safetySettings
)
)
Expand All @@ -90,6 +96,7 @@ public final class ImagenModel {

static func imageGenerationParameters(storageURI: String?,
generationConfig: ImagenGenerationConfig?,
modelConfig: ImagenModelConfig?,
safetySettings: ImagenSafetySettings?)
-> ImageGenerationParameters {
return ImageGenerationParameters(
Expand All @@ -99,13 +106,13 @@ public final class ImagenModel {
aspectRatio: generationConfig?.aspectRatio?.rawValue,
safetyFilterLevel: safetySettings?.safetyFilterLevel?.rawValue,
personGeneration: safetySettings?.personFilterLevel?.rawValue,
outputOptions: generationConfig?.imageFormat.map {
outputOptions: modelConfig?.imageFormat.map {
ImageGenerationOutputOptions(
mimeType: $0.mimeType,
compressionQuality: $0.compressionQuality
)
},
addWatermark: generationConfig?.addWatermark,
addWatermark: modelConfig?.addWatermark,
includeResponsibleAIFilterReason: true
)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ImagenModelConfig {
let imageFormat: ImagenImageFormat?
let addWatermark: Bool?

public init(imageFormat: ImagenImageFormat? = nil, addWatermark: Bool? = nil) {
self.imageFormat = imageFormat
self.addWatermark = addWatermark
}
}
4 changes: 3 additions & 1 deletion FirebaseVertexAI/Sources/VertexAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,14 @@ public class VertexAI {
)
}

public func imagenModel(modelName: String, safetySettings: ImagenSafetySettings? = nil,
public func imagenModel(modelName: String, modelConfig: ImagenModelConfig? = nil,
safetySettings: ImagenSafetySettings? = nil,
requestOptions: RequestOptions = RequestOptions()) -> ImagenModel {
return ImagenModel(
name: modelResourceName(modelName: modelName),
projectID: projectID,
apiKey: apiKey,
modelConfig: modelConfig,
safetySettings: safetySettings,
requestOptions: requestOptions,
appCheck: appCheck,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ final class IntegrationTests: XCTestCase {
)
imagenModel = vertex.imagenModel(
modelName: "imagen-3.0-fast-generate-001",
modelConfig: ImagenModelConfig(imageFormat: .jpeg(compressionQuality: 70)),
safetySettings: ImagenSafetySettings(
safetyFilterLevel: .blockLowAndAbove,
personFilterLevel: .blockAll
Expand Down Expand Up @@ -253,9 +254,7 @@ final class IntegrationTests: XCTestCase {
overlooking a vast African savanna at sunset. Golden hour light, long shadows, sharp focus on
the lion, shallow depth of field, detailed fur texture, DSLR, 85mm lens.
"""
var generationConfig = ImagenGenerationConfig()
generationConfig.aspectRatio = .landscape16x9
generationConfig.imageFormat = .jpeg(compressionQuality: 70)
let generationConfig = ImagenGenerationConfig(aspectRatio: .landscape16x9)

let response = try await imagenModel.generateImages(
prompt: imagePrompt,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ final class ImageGenerationParametersTests: XCTestCase {
let parameters = ImagenModel.imageGenerationParameters(
storageURI: nil,
generationConfig: nil,
modelConfig: nil,
safetySettings: nil
)

Expand All @@ -63,6 +64,37 @@ final class ImageGenerationParametersTests: XCTestCase {
let parameters = ImagenModel.imageGenerationParameters(
storageURI: storageURI,
generationConfig: nil,
modelConfig: nil,
safetySettings: nil
)

XCTAssertEqual(parameters, expectedParameters)
}

func testParameters_includeModelConfig() throws {
let compressionQuality = 80
let imageFormat = ImagenImageFormat.jpeg(compressionQuality: compressionQuality)
let addWatermark = true
let modelConfig = ImagenModelConfig(imageFormat: imageFormat, addWatermark: addWatermark)
let expectedParameters = ImageGenerationParameters(
sampleCount: 1,
storageURI: nil,
negativePrompt: nil,
aspectRatio: nil,
safetyFilterLevel: nil,
personGeneration: nil,
outputOptions: ImageGenerationOutputOptions(
mimeType: imageFormat.mimeType,
compressionQuality: imageFormat.compressionQuality
),
addWatermark: addWatermark,
includeResponsibleAIFilterReason: true
)

let parameters = ImagenModel.imageGenerationParameters(
storageURI: nil,
generationConfig: nil,
modelConfig: modelConfig,
safetySettings: nil
)

Expand All @@ -73,15 +105,10 @@ final class ImageGenerationParametersTests: XCTestCase {
let sampleCount = 2
let negativePrompt = "test-negative-prompt"
let aspectRatio = ImagenAspectRatio.landscape16x9
let compressionQuality = 80
let imageFormat = ImagenImageFormat.jpeg(compressionQuality: compressionQuality)
let addWatermark = true
let generationConfig = ImagenGenerationConfig(
numberOfImages: sampleCount,
negativePrompt: negativePrompt,
aspectRatio: aspectRatio,
imageFormat: imageFormat,
addWatermark: addWatermark
aspectRatio: aspectRatio
)
let expectedParameters = ImageGenerationParameters(
sampleCount: sampleCount,
Expand All @@ -90,24 +117,20 @@ final class ImageGenerationParametersTests: XCTestCase {
aspectRatio: aspectRatio.rawValue,
safetyFilterLevel: nil,
personGeneration: nil,
outputOptions: ImageGenerationOutputOptions(
mimeType: imageFormat.mimeType,
compressionQuality: imageFormat.compressionQuality
),
addWatermark: addWatermark,
outputOptions: nil,
addWatermark: nil,
includeResponsibleAIFilterReason: true
)

let parameters = ImagenModel.imageGenerationParameters(
storageURI: nil,
generationConfig: generationConfig,
modelConfig: nil,
safetySettings: nil
)

XCTAssertEqual(parameters, expectedParameters)
XCTAssertEqual(parameters.aspectRatio, "16:9")
XCTAssertEqual(parameters.outputOptions?.mimeType, "image/jpeg")
XCTAssertEqual(parameters.outputOptions?.compressionQuality, compressionQuality)
}

func testDefaultParameters_includeSafetySettings() throws {
Expand All @@ -132,6 +155,7 @@ final class ImageGenerationParametersTests: XCTestCase {
let parameters = ImagenModel.imageGenerationParameters(
storageURI: nil,
generationConfig: nil,
modelConfig: nil,
safetySettings: safetySettings
)

Expand All @@ -145,15 +169,14 @@ final class ImageGenerationParametersTests: XCTestCase {
let sampleCount = 4
let negativePrompt = "test-negative-prompt"
let aspectRatio = ImagenAspectRatio.portrait3x4
let imageFormat = ImagenImageFormat.png()
let addWatermark = false
let generationConfig = ImagenGenerationConfig(
numberOfImages: sampleCount,
negativePrompt: negativePrompt,
aspectRatio: aspectRatio,
imageFormat: imageFormat,
addWatermark: addWatermark
aspectRatio: aspectRatio
)
let imageFormat = ImagenImageFormat.png()
let addWatermark = false
let modelConfig = ImagenModelConfig(imageFormat: imageFormat, addWatermark: addWatermark)
let safetyFilterLevel = ImagenSafetyFilterLevel.blockNone
let personFilterLevel = ImagenPersonFilterLevel.blockAll
let safetySettings = ImagenSafetySettings(
Expand All @@ -178,6 +201,7 @@ final class ImageGenerationParametersTests: XCTestCase {
let parameters = ImagenModel.imageGenerationParameters(
storageURI: storageURI,
generationConfig: generationConfig,
modelConfig: modelConfig,
safetySettings: safetySettings
)

Expand Down
Loading