Skip to content

Commit

Permalink
Changes necessary to support protovalidate-java integration (#211)
Browse files Browse the repository at this point in the history
- Fix generation for map entries for maps of presence-tracked primitives in proto2, e.g. `map<int32, int32>`.
- Fix generation for maps with reserved names (e.g. `val`)
- Fix full type name generation for nested messages
- Add `KtProperty` annotation so that you can get a property's field number reflectively
  • Loading branch information
andrewparmet authored Dec 27, 2023
1 parent eeb79a6 commit a03694a
Show file tree
Hide file tree
Showing 15 changed files with 154 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ import com.squareup.kotlinpoet.withIndent
import protokt.v1.codegen.util.FieldType

fun deserializeVarInitialState(p: PropertyInfo) =
if (p.repeated || p.wrapped || p.nullable || p.fieldType == FieldType.Message) {
if (
(p.repeated || p.wrapped || p.nullable || p.fieldType == FieldType.Message) &&
!p.mapEntry
) {
CodeBlock.of("null")
} else {
p.defaultValue
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,15 @@ private class MapEntryGenerator(
private val key = msg.fields[0] as StandardField
private val value = msg.fields[1] as StandardField

private val keyProp = constructorProperty("key", key.className, false)
private val valProp = constructorProperty("value", value.className, false)

fun generate() =
TypeSpec.classBuilder(msg.className).apply {
addModifiers(KModifier.PRIVATE)
superclass(AbstractKtMessage::class)
addProperty(constructorProperty("key", key.className, false))
addProperty(constructorProperty("value", value.className, false))
addProperty(keyProp)
addProperty(valProp)
addConstructor()
addMessageSize()
addSerialize()
Expand Down Expand Up @@ -92,8 +95,8 @@ private class MapEntryGenerator(
buildFunSpec("serialize") {
addModifiers(KModifier.OVERRIDE)
addParameter("serializer", KtMessageSerializer::class)
addStatement("%L", serialize(key, ctx))
addStatement("%L", serialize(value, ctx))
addStatement("%L", serialize(key, ctx, keyProp))
addStatement("%L", serialize(value, ctx, valProp))
}
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import com.squareup.kotlinpoet.asTypeName
import com.squareup.kotlinpoet.buildCodeBlock
import protokt.v1.AbstractKtMessage
import protokt.v1.KtGeneratedMessage
import protokt.v1.KtProperty
import protokt.v1.UnknownFieldSet
import protokt.v1.codegen.generate.CodeGenerator.Context
import protokt.v1.codegen.generate.CodeGenerator.generate
Expand All @@ -49,15 +50,16 @@ private class MessageGenerator(
generateMapEntry(msg, ctx)
} else {
val properties = annotateProperties(msg, ctx)
val propertySpecs = properties(properties)

TypeSpec.classBuilder(msg.className).apply {
annotateMessageDocumentation(ctx)?.let { addKdoc(formatDoc(it)) }
handleAnnotations()
handleConstructor(properties)
handleConstructor(propertySpecs)
addTypes(annotateOneofs(msg, ctx))
handleMessageSize()
addFunction(generateMessageSize(msg, ctx))
addFunction(generateSerializer(msg, ctx))
addFunction(generateMessageSize(msg, propertySpecs, ctx))
addFunction(generateSerializer(msg, propertySpecs, ctx))
handleEquals(properties)
handleHashCode(properties)
handleToString(properties)
Expand All @@ -80,21 +82,10 @@ private class MessageGenerator(
}

private fun TypeSpec.Builder.handleConstructor(
properties: List<PropertyInfo>
properties: List<PropertySpec>
) = apply {
superclass(AbstractKtMessage::class)
addProperties(
properties.map { property ->
PropertySpec.builder(property.name, property.propertyType).apply {
initializer(property.name)
if (property.overrides) {
addModifiers(KModifier.OVERRIDE)
}
property.documentation?.let { addKdoc(formatDoc(it)) }
handleDeprecation(property.deprecation)
}.build()
}
)
addProperties(properties)
addProperty(
PropertySpec.builder("unknownFields", UnknownFieldSet::class)
.initializer("unknownFields")
Expand All @@ -103,7 +94,7 @@ private class MessageGenerator(
primaryConstructor(
FunSpec.constructorBuilder()
.addModifiers(KModifier.PRIVATE)
.addParameters(properties.map { ParameterSpec(it.name, it.propertyType) })
.addParameters(properties.map { ParameterSpec(it.name, it.type) })
.addParameter(
ParameterSpec.builder("unknownFields", UnknownFieldSet::class)
.defaultValue("%T.empty()", UnknownFieldSet::class)
Expand All @@ -114,6 +105,25 @@ private class MessageGenerator(
handleSuperInterface(msg, ctx)
}

private fun properties(properties: List<PropertyInfo>) =
properties.map { property ->
PropertySpec.builder(property.name, property.propertyType).apply {
if (property.number != null) {
addAnnotation(
AnnotationSpec.builder(KtProperty::class)
.addMember("${property.number}")
.build()
)
}
initializer(property.name)
if (property.overrides) {
addModifiers(KModifier.OVERRIDE)
}
property.documentation?.let { addKdoc(formatDoc(it)) }
handleDeprecation(property.deprecation)
}.build()
}

private fun TypeSpec.Builder.handleMessageSize() =
addProperty(
PropertySpec.builder("messageSize", Int::class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package protokt.v1.codegen.generate
import com.squareup.kotlinpoet.CodeBlock
import com.squareup.kotlinpoet.FunSpec
import com.squareup.kotlinpoet.KModifier
import com.squareup.kotlinpoet.PropertySpec
import com.squareup.kotlinpoet.buildCodeBlock
import protokt.v1.codegen.generate.CodeGenerator.Context
import protokt.v1.codegen.generate.Nullability.hasNonNullOption
Expand All @@ -31,11 +32,12 @@ import protokt.v1.codegen.util.Message
import protokt.v1.codegen.util.Oneof
import protokt.v1.codegen.util.StandardField

fun generateMessageSize(msg: Message, ctx: Context) =
MessageSizeGenerator(msg, ctx).generate()
fun generateMessageSize(msg: Message, properties: List<PropertySpec>, ctx: Context) =
MessageSizeGenerator(msg, properties, ctx).generate()

private class MessageSizeGenerator(
private val msg: Message,
private val properties: List<PropertySpec>,
private val ctx: Context
) {
private val resultVarName =
Expand All @@ -51,9 +53,10 @@ private class MessageSizeGenerator(
val fieldSizes =
msg.mapFields(
ctx,
properties,
false,
{ CodeBlock.of("$resultVarName·+=·%L", sizeOf(it, ctx)) },
{ oneof, std -> sizeofOneof(oneof, std) },
{ std, _ -> CodeBlock.of("$resultVarName·+=·%L", sizeOf(std, ctx)) },
{ oneof, std, _ -> sizeofOneof(oneof, std) },
{
if (it.hasNonNullOption) {
add("$resultVarName·+=·")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import com.squareup.kotlinpoet.KModifier
import com.squareup.kotlinpoet.PropertySpec
import com.squareup.kotlinpoet.TypeName
import com.squareup.kotlinpoet.TypeSpec
import protokt.v1.KtProperty
import protokt.v1.codegen.generate.CodeGenerator.Context
import protokt.v1.codegen.generate.Deprecation.renderOptions
import protokt.v1.codegen.generate.Implements.handleSuperInterface
Expand Down Expand Up @@ -73,6 +74,11 @@ private class OneofGenerator(
}
.addProperty(
PropertySpec.builder(v.fieldName, v.type)
.addAnnotation(
AnnotationSpec.builder(KtProperty::class)
.addMember("${v.number}")
.build()
)
.initializer(v.fieldName)
.build()
)
Expand All @@ -96,6 +102,7 @@ private class OneofGenerator(
private fun info(f: StandardField) =
OneofGeneratorInfo(
fieldName = f.fieldName,
number = f.number,
type =
if (f.wrapped) {
interceptTypeName(f, ctx) ?: f.className
Expand All @@ -119,6 +126,7 @@ private class OneofGenerator(
class OneofGeneratorInfo(
val fieldName: String,
val type: TypeName,
val number: Int,
val documentation: List<String>?,
val deprecation: Deprecation.RenderOptions?
)
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,15 @@ private class PropertyAnnotator(
val wrapperRequiresNullability = field.wrapperRequiresNullability(ctx)
PropertyInfo(
name = field.fieldName,
number = field.number,
propertyType = propertyType(field, type, wrapperRequiresNullability),
deserializeType = deserializeType(field, type),
builderPropertyType = dslPropertyType(field, type),
defaultValue = field.defaultValue(ctx),
defaultValue = field.defaultValue(ctx, msg.mapEntry),
fieldType = field.type,
repeated = field.repeated,
map = field.map,
mapEntry = msg.mapEntry,
nullable = field.nullable || field.optional || wrapperRequiresNullability,
nonNullOption = field.hasNonNullOption,
overrides = field.overrides(ctx, msg),
Expand All @@ -81,7 +83,7 @@ private class PropertyAnnotator(
propertyType = propertyType(field),
deserializeType = field.className.copy(nullable = true),
builderPropertyType = field.className.copy(nullable = true),
defaultValue = field.defaultValue(ctx),
defaultValue = field.defaultValue(ctx, false),
oneof = true,
nullable = field.nullable,
nonNullOption = field.hasNonNullOption,
Expand Down Expand Up @@ -128,7 +130,7 @@ private class PropertyAnnotator(
val vType: TypeName
)

private fun Field.defaultValue(ctx: Context) =
private fun Field.defaultValue(ctx: Context, mapEntry: Boolean) =
when (this) {
is StandardField ->
interceptDefaultValue(
Expand All @@ -138,7 +140,7 @@ private class PropertyAnnotator(
repeated -> CodeBlock.of("emptyList()")
type == FieldType.Message -> CodeBlock.of("null")
type == FieldType.Enum -> CodeBlock.of("%T.from(0)", className)
nullable -> CodeBlock.of("null")
nullable && !mapEntry -> CodeBlock.of("null")
else -> type.defaultValue
},
ctx
Expand All @@ -149,12 +151,14 @@ private class PropertyAnnotator(

class PropertyInfo(
val name: String,
val number: Int? = null,
val propertyType: TypeName,
val deserializeType: TypeName,
val builderPropertyType: TypeName,
val defaultValue: CodeBlock,
val nullable: Boolean,
val nonNullOption: Boolean,
val mapEntry: Boolean = false,
val fieldType: FieldType? = null,
val repeated: Boolean = false,
val map: Boolean = false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package protokt.v1.codegen.generate

import com.squareup.kotlinpoet.CodeBlock
import com.squareup.kotlinpoet.PropertySpec
import com.squareup.kotlinpoet.buildCodeBlock
import protokt.v1.codegen.generate.CodeGenerator.Context
import protokt.v1.codegen.generate.Nullability.hasNonNullOption
Expand All @@ -29,19 +30,21 @@ import protokt.v1.codegen.util.StandardField

fun Message.mapFields(
ctx: Context,
properties: List<PropertySpec>,
skipConditionalForUnpackedRepeatedFields: Boolean,
std: (StandardField) -> CodeBlock,
oneof: (Oneof, StandardField) -> CodeBlock,
std: (StandardField, PropertySpec) -> CodeBlock,
oneof: (Oneof, StandardField, PropertySpec) -> CodeBlock,
oneofPreControlFlow: CodeBlock.Builder.(Oneof) -> Unit = {}
): List<CodeBlock> =
fields.map { field ->
when (field) {
is StandardField ->
standardFieldExecution(ctx, field, skipConditionalForUnpackedRepeatedFields) { std(field) }
is Oneof ->
oneofFieldExecution(field, { oneof(field, it) }, oneofPreControlFlow)
fields.zip(properties)
.map { (field, property) ->
when (field) {
is StandardField ->
standardFieldExecution(ctx, field, skipConditionalForUnpackedRepeatedFields) { std(field, property) }
is Oneof ->
oneofFieldExecution(field, { oneof(field, it, property) }, oneofPreControlFlow)
}
}
}

private fun standardFieldExecution(
ctx: Context,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package protokt.v1.codegen.generate
import com.squareup.kotlinpoet.CodeBlock
import com.squareup.kotlinpoet.FunSpec
import com.squareup.kotlinpoet.KModifier
import com.squareup.kotlinpoet.PropertySpec
import com.squareup.kotlinpoet.buildCodeBlock
import protokt.v1.KtMessageSerializer
import protokt.v1.codegen.generate.CodeGenerator.Context
Expand All @@ -29,20 +30,22 @@ import protokt.v1.codegen.util.Message
import protokt.v1.codegen.util.Oneof
import protokt.v1.codegen.util.StandardField

fun generateSerializer(msg: Message, ctx: Context) =
SerializerGenerator(msg, ctx).generate()
fun generateSerializer(msg: Message, properties: List<PropertySpec>, ctx: Context) =
SerializerGenerator(msg, properties, ctx).generate()

private class SerializerGenerator(
private val msg: Message,
private val properties: List<PropertySpec>,
private val ctx: Context
) {
fun generate(): FunSpec {
val fieldSerializations =
msg.mapFields(
ctx,
properties,
true,
{ serialize(it, ctx) },
{ oneof, std -> serialize(std, ctx, oneof) }
{ f, p -> serialize(f, ctx, p) },
{ oneof, std, p -> serialize(std, ctx, p, oneof) }
)

return buildFunSpec("serialize") {
Expand All @@ -57,14 +60,15 @@ private class SerializerGenerator(
fun serialize(
f: StandardField,
ctx: Context,
p: PropertySpec,
o: Oneof? = null
): CodeBlock {
val fieldAccess =
if (o == null) {
interceptValueAccess(
f,
ctx,
if (f.repeated) { CodeBlock.of("it") } else { CodeBlock.of("%N", f.fieldName) }
if (f.repeated) { CodeBlock.of("it") } else { CodeBlock.of("%N", p) }
)
} else {
interceptValueAccess(f, ctx, CodeBlock.of("%N.%N", o.fieldName, f.fieldName))
Expand All @@ -80,10 +84,10 @@ fun serialize(
"elementsSize" to f.elementsSize()
)
)
add("%N.forEach·{·serializer.%L·}", f.fieldName, f.write(CodeBlock.of("it")))
add("%N.forEach·{·serializer.%L·}", p, f.write(CodeBlock.of("it")))
}
f.map -> buildCodeBlock {
beginControlFlow("${f.fieldName}.entries.forEach")
beginControlFlow("%N.entries.forEach", p)
add(
"serializer.writeTag(${f.tag.value}u).write(%L)\n",
f.boxMap(ctx)
Expand All @@ -95,7 +99,7 @@ fun serialize(
"%name:N.forEach·{·" +
"serializer.writeTag(${f.tag.value}u).%write:L·}",
mapOf(
"name" to f.fieldName,
"name" to p,
"write" to f.write(fieldAccess)
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@ class MessageParser(
private val enclosingMessages: List<String>
) {
fun toMessage(): Message {
val typeName = desc.name
val fieldList = FieldParser(ctx, desc, enclosingMessages).toFields()
val simpleNames = enclosingMessages + typeName
val simpleNames = enclosingMessages + desc.name
return Message(
fields = fieldList.sortedBy {
when (it) {
Expand All @@ -48,7 +47,7 @@ class MessageParser(
desc.options.getExtension(ProtoktProtos.class_)
),
index = idx,
fullProtobufTypeName = "${ctx.fdp.`package`}.$typeName",
fullProtobufTypeName = "${ctx.fdp.`package`}.${simpleNames.joinToString(".")}",
className = ctx.className(simpleNames),
deserializerClassName = ctx.className(simpleNames + DESERIALIZER)
)
Expand Down
Loading

0 comments on commit a03694a

Please sign in to comment.