Skip to content

Commit

Permalink
Migrate away from enums (#6340)
Browse files Browse the repository at this point in the history
Per [b/370771226](https://b.corp.google.com/issues/370771226),

This refactors all our enums in vertex to be classes instead. While this
means no more exhaustive `when`, this allows us to add new values in the
future without breaking the API.

Since we (android) can only perform breaking changes [effectively] every
six months, this will allow us to align with the [evolving] backend
significantly faster.

This also adds a test to ensure the conversion layer is updated whenever
any of these values are updated- since we no longer have the exhaustive
`when` to catch such cases.

---------

Co-authored-by: Rodrigo Lazo <rlazo@users.noreply.github.com>
  • Loading branch information
daymxn and rlazo authored Oct 4, 2024
1 parent f761b2c commit e417d5d
Show file tree
Hide file tree
Showing 10 changed files with 173 additions and 70 deletions.
1 change: 1 addition & 0 deletions firebase-vertexai/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Unreleased
* [changed] Breaking Change: refactored enum classes to be normal classes (#6340).


# 16.0.0-beta05
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,18 @@ internal fun SafetySetting.toInternal() =
method.toInternal()
)

internal fun makeMissingCaseException(source: String, ordinal: Int): SerializationException {
return SerializationException(
"""
|Missing case for a $source: $ordinal
|This error indicates that one of the `toInternal` conversions needs updating.
|If you're a developer seeing this exception, please file an issue on our GitHub repo:
|https://github.com/firebase/firebase-android-sdk
"""
.trimMargin()
)
}

internal fun GenerationConfig.toInternal() =
com.google.firebase.vertexai.common.client.GenerationConfig(
temperature = temperature,
Expand All @@ -132,13 +144,15 @@ internal fun HarmCategory.toInternal() =
HarmCategory.DANGEROUS_CONTENT ->
com.google.firebase.vertexai.common.shared.HarmCategory.DANGEROUS_CONTENT
HarmCategory.UNKNOWN -> com.google.firebase.vertexai.common.shared.HarmCategory.UNKNOWN
else -> throw makeMissingCaseException("HarmCategory", ordinal)
}

internal fun HarmBlockMethod.toInternal() =
when (this) {
HarmBlockMethod.SEVERITY -> com.google.firebase.vertexai.common.shared.HarmBlockMethod.SEVERITY
HarmBlockMethod.PROBABILITY ->
com.google.firebase.vertexai.common.shared.HarmBlockMethod.PROBABILITY
else -> throw makeMissingCaseException("HarmBlockMethod", ordinal)
}

internal fun ToolConfig.toInternal() =
Expand Down Expand Up @@ -166,6 +180,7 @@ internal fun HarmBlockThreshold.toInternal() =
com.google.firebase.vertexai.common.shared.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE
HarmBlockThreshold.LOW_AND_ABOVE ->
com.google.firebase.vertexai.common.shared.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
else -> throw makeMissingCaseException("HarmBlockThreshold", ordinal)
}

internal fun Tool.toInternal() =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,22 +70,24 @@ internal constructor(
)

/** The reason for content finishing. */
public enum class FinishReason {
/** A new and not yet supported value. */
UNKNOWN,
public class FinishReason private constructor(public val name: String, public val ordinal: Int) {
public companion object {
/** A new and not yet supported value. */
@JvmField public val UNKNOWN: FinishReason = FinishReason("UNKNOWN", 0)

/** Model finished successfully and stopped. */
STOP,
/** Model finished successfully and stopped. */
@JvmField public val STOP: FinishReason = FinishReason("STOP", 1)

/** Model hit the token limit. */
MAX_TOKENS,
/** Model hit the token limit. */
@JvmField public val MAX_TOKENS: FinishReason = FinishReason("MAX_TOKENS", 2)

/** [SafetySetting] prevented the model from outputting content. */
SAFETY,
/** [SafetySetting] prevented the model from outputting content. */
@JvmField public val SAFETY: FinishReason = FinishReason("SAFETY", 3)

/** Model began looping. */
RECITATION,
/** Model began looping. */
@JvmField public val RECITATION: FinishReason = FinishReason("RECITATION", 4)

/** Model stopped for another reason. */
OTHER
/** Model stopped for another reason. */
@JvmField public val OTHER: FinishReason = FinishReason("OTHER", 5)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@ package com.google.firebase.vertexai.type
* Specifies how the block method computes the score that will be compared against the
* [HarmBlockThreshold] in [SafetySetting].
*/
public enum class HarmBlockMethod {
/**
* The harm block method uses both probability and severity scores. See [HarmSeverity] and
* [HarmProbability].
*/
SEVERITY,
/** The harm block method uses the probability score. See [HarmProbability]. */
PROBABILITY,
public class HarmBlockMethod private constructor(public val ordinal: Int) {
public companion object {
/**
* The harm block method uses both probability and severity scores. See [HarmSeverity] and
* [HarmProbability].
*/
@JvmField public val SEVERITY: HarmBlockMethod = HarmBlockMethod(0)

/** The harm block method uses the probability score. See [HarmProbability]. */
@JvmField public val PROBABILITY: HarmBlockMethod = HarmBlockMethod(1)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,18 @@
package com.google.firebase.vertexai.type

/** Represents the threshold for a [HarmCategory] to be allowed by [SafetySetting]. */
public enum class HarmBlockThreshold {
/** Content with negligible harm is allowed. */
LOW_AND_ABOVE,
public class HarmBlockThreshold private constructor(public val ordinal: Int) {
public companion object {
/** Content with negligible harm is allowed. */
@JvmField public val LOW_AND_ABOVE: HarmBlockThreshold = HarmBlockThreshold(0)

/** Content with negligible to low harm is allowed. */
MEDIUM_AND_ABOVE,
/** Content with negligible to low harm is allowed. */
@JvmField public val MEDIUM_AND_ABOVE: HarmBlockThreshold = HarmBlockThreshold(1)

/** Content with negligible to medium harm is allowed. */
ONLY_HIGH,
/** Content with negligible to medium harm is allowed. */
@JvmField public val ONLY_HIGH: HarmBlockThreshold = HarmBlockThreshold(2)

/** All content is allowed regardless of harm. */
NONE
/** All content is allowed regardless of harm. */
@JvmField public val NONE: HarmBlockThreshold = HarmBlockThreshold(3)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,21 @@
package com.google.firebase.vertexai.type

/** Category for a given harm rating. */
public enum class HarmCategory {
/** A new and not yet supported value. */
UNKNOWN,
public class HarmCategory private constructor(public val ordinal: Int) {
public companion object {
/** A new and not yet supported value. */
@JvmField public val UNKNOWN: HarmCategory = HarmCategory(0)

/** Harassment content. */
HARASSMENT,
/** Harassment content. */
@JvmField public val HARASSMENT: HarmCategory = HarmCategory(1)

/** Hate speech and content. */
HATE_SPEECH,
/** Hate speech and content. */
@JvmField public val HATE_SPEECH: HarmCategory = HarmCategory(2)

/** Sexually explicit content. */
SEXUALLY_EXPLICIT,
/** Sexually explicit content. */
@JvmField public val SEXUALLY_EXPLICIT: HarmCategory = HarmCategory(3)

/** Dangerous content. */
DANGEROUS_CONTENT
/** Dangerous content. */
@JvmField public val DANGEROUS_CONTENT: HarmCategory = HarmCategory(4)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,21 @@
package com.google.firebase.vertexai.type

/** Represents the probability that some [HarmCategory] is applicable in a [SafetyRating]. */
public enum class HarmProbability {
/** A new and not yet supported value. */
UNKNOWN,
public class HarmProbability private constructor(public val ordinal: Int) {
public companion object {
/** A new and not yet supported value. */
@JvmField public val UNKNOWN: HarmProbability = HarmProbability(0)

/** Probability for harm is negligible. */
NEGLIGIBLE,
/** Probability for harm is negligible. */
@JvmField public val NEGLIGIBLE: HarmProbability = HarmProbability(1)

/** Probability for harm is low. */
LOW,
/** Probability for harm is low. */
@JvmField public val LOW: HarmProbability = HarmProbability(2)

/** Probability for harm is medium. */
MEDIUM,
/** Probability for harm is medium. */
@JvmField public val MEDIUM: HarmProbability = HarmProbability(3)

/** Probability for harm is high. */
HIGH,
/** Probability for harm is high. */
@JvmField public val HIGH: HarmProbability = HarmProbability(4)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,21 @@
package com.google.firebase.vertexai.type

/** Represents the severity of a [HarmCategory] being applicable in a [SafetyRating]. */
public enum class HarmSeverity {
/** A new and not yet supported value. */
UNKNOWN,
public class HarmSeverity private constructor(public val ordinal: Int) {
public companion object {
/** A new and not yet supported value. */
@JvmField public val UNKNOWN: HarmSeverity = HarmSeverity(0)

/** Severity for harm is negligible. */
NEGLIGIBLE,
/** Severity for harm is negligible. */
@JvmField public val NEGLIGIBLE: HarmSeverity = HarmSeverity(1)

/** Low level of harm severity. */
LOW,
/** Low level of harm severity. */
@JvmField public val LOW: HarmSeverity = HarmSeverity(2)

/** Medium level of harm severity. */
MEDIUM,
/** Medium level of harm severity. */
@JvmField public val MEDIUM: HarmSeverity = HarmSeverity(3)

/** High level of harm severity. */
HIGH,
/** High level of harm severity. */
@JvmField public val HIGH: HarmSeverity = HarmSeverity(4)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,15 @@ public class PromptFeedback(
)

/** Describes why content was blocked. */
public enum class BlockReason {
/** A new and not yet supported value. */
UNKNOWN,
public class BlockReason private constructor(public val name: String, public val ordinal: Int) {
public companion object {
/** A new and not yet supported value. */
@JvmField public val UNKNOWN: BlockReason = BlockReason("UNKNOWN", 0)

/** Content was blocked for violating provided [SafetySetting]. */
SAFETY,
/** Content was blocked for violating provided [SafetySetting]. */
@JvmField public val SAFETY: BlockReason = BlockReason("SAFETY", 1)

/** Content was blocked for another reason. */
OTHER
/** Content was blocked for another reason. */
@JvmField public val OTHER: BlockReason = BlockReason("OTHER", 2)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.firebase.vertexai.common

import com.google.firebase.vertexai.internal.util.toInternal
import com.google.firebase.vertexai.type.HarmBlockMethod
import com.google.firebase.vertexai.type.HarmBlockThreshold
import com.google.firebase.vertexai.type.HarmCategory
import org.junit.Test

/**
* Fetches all the `@JvmStatic` properties of a class that are instances of the class itself.
*
* For example, given the following class:
* ```kt
* public class HarmCategory private constructor(public val ordinal: Int) {
* public companion object {
* @JvmField public val UNKNOWN: HarmCategory = HarmCategory(0)
* @JvmField public val HARASSMENT: HarmCategory = HarmCategory(1)
* }
* }
* ```
* This function will yield:
* ```kt
* [UNKNOWN, HARASSMENT]
* ```
*/
internal inline fun <reified T : Any> getEnumValues(): List<T> {
return T::class
.java
.declaredFields
.filter { it.type == T::class.java }
.mapNotNull { it.get(null) as? T }
}

/**
* Ensures that whenever any of our "pseudo-enums" are updated, that the conversion layer is also
* updated.
*/
internal class EnumUpdateTests {
@Test
fun `HarmCategory#toInternal() covers all values`() {
val values = getEnumValues<HarmCategory>()
values.forEach { it.toInternal() }
}

@Test
fun `HarmBlockMethod#toInternal() covers all values`() {
val values = getEnumValues<HarmBlockMethod>()
values.forEach { it.toInternal() }
}

@Test
fun `HarmBlockThreshold#toInternal() covers all values`() {
val values = getEnumValues<HarmBlockThreshold>()
values.forEach { it.toInternal() }
}
}

0 comments on commit e417d5d

Please sign in to comment.