From 186edad11d71eeef2dbf25f2081b492f90e92d02 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Mon, 20 Jan 2025 16:54:04 -0500 Subject: [PATCH] [Vertex AI] Add `apiVersion` parameter to `RequestOptions` --- .../firebase/vertexai/type/ApiVersion.kt | 28 ++++++++++++ .../firebase/vertexai/type/RequestOptions.kt | 14 ++++-- .../vertexai/GenerativeModelTesting.kt | 2 +- .../vertexai/common/APIControllerTests.kt | 44 ++++++++++++++++++- 4 files changed, 81 insertions(+), 7 deletions(-) create mode 100644 firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ApiVersion.kt diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ApiVersion.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ApiVersion.kt new file mode 100644 index 00000000000..c8b11fd4451 --- /dev/null +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ApiVersion.kt @@ -0,0 +1,28 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.vertexai.type + +/** Versions of the Vertex AI in Firebase server API. */ +public class ApiVersion private constructor(internal val value: String) { + public companion object { + /** The stable channel for version 1 of the API. */ + @JvmField public val V1: ApiVersion = ApiVersion("v1") + + /** The beta channel for version 1 of the API. */ + @JvmField public val V1BETA: ApiVersion = ApiVersion("v1beta") + } +} diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/RequestOptions.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/RequestOptions.kt index 9aa648b6d07..02f9b569357 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/RequestOptions.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/RequestOptions.kt @@ -26,17 +26,23 @@ public class RequestOptions internal constructor( internal val timeout: Duration, internal val endpoint: String = "https://firebasevertexai.googleapis.com", - internal val apiVersion: String = "v1beta", + internal val apiVersion: String, ) { /** * Constructor for RequestOptions. * * @param timeoutInMillis the maximum amount of time, in milliseconds, for a request to take, from - * the first request to first response. + * the first request to first response. + * @param apiVersion the version of the Vertex AI in Firebase API; defaults to [ApiVersion.V1BETA] + * if not specified. */ @JvmOverloads public constructor( - timeoutInMillis: Long = 180.seconds.inWholeMilliseconds - ) : this(timeout = timeoutInMillis.toDuration(DurationUnit.MILLISECONDS)) + timeoutInMillis: Long = 180.seconds.inWholeMilliseconds, + apiVersion: ApiVersion = ApiVersion.V1BETA, + ) : this( + timeout = timeoutInMillis.toDuration(DurationUnit.MILLISECONDS), + apiVersion = apiVersion.value, + ) } diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/GenerativeModelTesting.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/GenerativeModelTesting.kt index 8b668371a31..f20ab347c3a 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/GenerativeModelTesting.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/GenerativeModelTesting.kt @@ -60,7 +60,7 @@ internal class GenerativeModelTesting { APIController( "super_cool_test_key", "gemini-1.5-flash", - RequestOptions(timeout = 5.seconds, endpoint = "https://my.custom.endpoint"), + RequestOptions(5.seconds.inWholeMilliseconds), mockEngine, TEST_CLIENT_ID, null, diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt index 8937b13569b..49169ff6425 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt @@ -26,11 +26,14 @@ import com.google.firebase.vertexai.common.util.commonTest import com.google.firebase.vertexai.common.util.createResponses import com.google.firebase.vertexai.common.util.doBlocking import com.google.firebase.vertexai.common.util.prepareStreamingResponse +import com.google.firebase.vertexai.type.ApiVersion import com.google.firebase.vertexai.type.RequestOptions import io.kotest.assertions.json.shouldContainJsonKey import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.shouldBe +import io.kotest.matchers.shouldNotBe import io.kotest.matchers.string.shouldContain +import io.kotest.matchers.string.shouldStartWith import io.ktor.client.engine.mock.MockEngine import io.ktor.client.engine.mock.respond import io.ktor.content.TextContent @@ -74,7 +77,7 @@ internal class APIControllerTests { @Test fun `(generateContent) respects a custom timeout`() = - commonTest(requestOptions = RequestOptions(2.seconds)) { + commonTest(requestOptions = RequestOptions(2.seconds.inWholeMilliseconds)) { shouldThrow { withTimeout(testTimeout) { apiController.generateContent(textGenerateContentRequest("test")) @@ -122,7 +125,11 @@ internal class RequestFormatTests { APIController( "super_cool_test_key", "gemini-pro-1.5", - RequestOptions(timeout = 5.seconds, endpoint = "https://my.custom.endpoint"), + RequestOptions( + timeout = 5.seconds, + endpoint = "https://my.custom.endpoint", + apiVersion = "v1beta" + ), mockEngine, TEST_CLIENT_ID, null, @@ -138,6 +145,39 @@ internal class RequestFormatTests { mockEngine.requestHistory.first().url.host shouldBe "my.custom.endpoint" } + @Test + fun `using custom API version`() = doBlocking { + val channel = ByteChannel(autoFlush = true) + val mockEngine = MockEngine { + respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) + } + prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) } + val controller = + APIController( + "super_cool_test_key", + "gemini-pro-1.5", + RequestOptions( + timeoutInMillis = 5.seconds.inWholeMilliseconds, + apiVersion = ApiVersion.V1 + ), + mockEngine, + TEST_CLIENT_ID, + null, + ) + + withTimeout(5.seconds) { + controller.generateContentStream(textGenerateContentRequest("cats")).collect { + it.candidates?.isEmpty() shouldBe false + channel.close() + } + } + + mockEngine.requestHistory.first().url.encodedPath shouldStartWith "/${ApiVersion.V1.value}" + // TODO: Update test to set ApiVersion.V1BETA when ApiVersion.V1 becomes the default and delete + // the following check. + RequestOptions().apiVersion shouldNotBe ApiVersion.V1.value + } + @Test fun `client id header is set correctly in the request`() = doBlocking { val response = JSON.encodeToString(CountTokensResponse(totalTokens = 10))