Skip to content

Commit

Permalink
[Vertex AI] Add apiVersion parameter to RequestOptions
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard committed Jan 20, 2025
1 parent 8e61ec2 commit 186edad
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.firebase.vertexai.type

/** Versions of the Vertex AI in Firebase server API. */
public class ApiVersion private constructor(internal val value: String) {
public companion object {
/** The stable channel for version 1 of the API. */
@JvmField public val V1: ApiVersion = ApiVersion("v1")

/** The beta channel for version 1 of the API. */
@JvmField public val V1BETA: ApiVersion = ApiVersion("v1beta")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,23 @@ public class RequestOptions
internal constructor(
internal val timeout: Duration,
internal val endpoint: String = "https://firebasevertexai.googleapis.com",
internal val apiVersion: String = "v1beta",
internal val apiVersion: String,
) {

/**
* Constructor for RequestOptions.
*
* @param timeoutInMillis the maximum amount of time, in milliseconds, for a request to take, from
* the first request to first response.
* the first request to first response.
* @param apiVersion the version of the Vertex AI in Firebase API; defaults to [ApiVersion.V1BETA]
* if not specified.
*/
@JvmOverloads
public constructor(
timeoutInMillis: Long = 180.seconds.inWholeMilliseconds
) : this(timeout = timeoutInMillis.toDuration(DurationUnit.MILLISECONDS))
timeoutInMillis: Long = 180.seconds.inWholeMilliseconds,
apiVersion: ApiVersion = ApiVersion.V1BETA,
) : this(
timeout = timeoutInMillis.toDuration(DurationUnit.MILLISECONDS),
apiVersion = apiVersion.value,
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ internal class GenerativeModelTesting {
APIController(
"super_cool_test_key",
"gemini-1.5-flash",
RequestOptions(timeout = 5.seconds, endpoint = "https://my.custom.endpoint"),
RequestOptions(5.seconds.inWholeMilliseconds),
mockEngine,
TEST_CLIENT_ID,
null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@ import com.google.firebase.vertexai.common.util.commonTest
import com.google.firebase.vertexai.common.util.createResponses
import com.google.firebase.vertexai.common.util.doBlocking
import com.google.firebase.vertexai.common.util.prepareStreamingResponse
import com.google.firebase.vertexai.type.ApiVersion
import com.google.firebase.vertexai.type.RequestOptions
import io.kotest.assertions.json.shouldContainJsonKey
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.shouldBe
import io.kotest.matchers.shouldNotBe
import io.kotest.matchers.string.shouldContain
import io.kotest.matchers.string.shouldStartWith
import io.ktor.client.engine.mock.MockEngine
import io.ktor.client.engine.mock.respond
import io.ktor.content.TextContent
Expand Down Expand Up @@ -74,7 +77,7 @@ internal class APIControllerTests {

@Test
fun `(generateContent) respects a custom timeout`() =
commonTest(requestOptions = RequestOptions(2.seconds)) {
commonTest(requestOptions = RequestOptions(2.seconds.inWholeMilliseconds)) {
shouldThrow<RequestTimeoutException> {
withTimeout(testTimeout) {
apiController.generateContent(textGenerateContentRequest("test"))
Expand Down Expand Up @@ -122,7 +125,11 @@ internal class RequestFormatTests {
APIController(
"super_cool_test_key",
"gemini-pro-1.5",
RequestOptions(timeout = 5.seconds, endpoint = "https://my.custom.endpoint"),
RequestOptions(
timeout = 5.seconds,
endpoint = "https://my.custom.endpoint",
apiVersion = "v1beta"
),
mockEngine,
TEST_CLIENT_ID,
null,
Expand All @@ -138,6 +145,39 @@ internal class RequestFormatTests {
mockEngine.requestHistory.first().url.host shouldBe "my.custom.endpoint"
}

@Test
fun `using custom API version`() = doBlocking {
val channel = ByteChannel(autoFlush = true)
val mockEngine = MockEngine {
respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
val controller =
APIController(
"super_cool_test_key",
"gemini-pro-1.5",
RequestOptions(
timeoutInMillis = 5.seconds.inWholeMilliseconds,
apiVersion = ApiVersion.V1
),
mockEngine,
TEST_CLIENT_ID,
null,
)

withTimeout(5.seconds) {
controller.generateContentStream(textGenerateContentRequest("cats")).collect {
it.candidates?.isEmpty() shouldBe false
channel.close()
}
}

mockEngine.requestHistory.first().url.encodedPath shouldStartWith "/${ApiVersion.V1.value}"
// TODO: Update test to set ApiVersion.V1BETA when ApiVersion.V1 becomes the default and delete
// the following check.
RequestOptions().apiVersion shouldNotBe ApiVersion.V1.value
}

@Test
fun `client id header is set correctly in the request`() = doBlocking {
val response = JSON.encodeToString(CountTokensResponse(totalTokens = 10))
Expand Down

0 comments on commit 186edad

Please sign in to comment.