Skip to content

Commit

Permalink
[kotlin] Fix nested types input/output code generation (#273)
Browse files Browse the repository at this point in the history
* [kotlin] Fix nested types input/output code generation
* Fix nested classes issue
  • Loading branch information
slinkydeveloper authored Jun 3, 2024
1 parent 0e1434a commit 8f90998
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 6 deletions.
2 changes: 2 additions & 0 deletions sdk-api-kotlin-gen/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
plugins {
java
kotlin("jvm")
kotlin("plugin.serialization")
`library-publishing-conventions`
alias(kotlinLibs.plugins.ksp)
}
Expand All @@ -22,6 +23,7 @@ dependencies {
testImplementation(coreLibs.protobuf.java)
testImplementation(coreLibs.log4j.core)
testImplementation(kotlinLibs.kotlinx.coroutines)
testImplementation(kotlinLibs.kotlinx.serialization.core)

// Import test suites from sdk-core
testImplementation(project(":sdk-core", "testArchive"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,13 @@ class KElementConverter(
if (rawAnnotation != null && ty != byteArrayType) {
logger.error("A parameter annotated with @Raw MUST be of type byte[], was $ty", relatedNode)
}
if (ty.isFunctionType || ty.isSuspendFunctionType) {
logger.error("Cannot use fun as parameter or return type", relatedNode)
}

var serdeDecl: String = if (rawAnnotation != null) RAW_SERDE else jsonSerdeDecl(ty)
val qualifiedTypeName = qualifiedTypeName(ty)
var serdeDecl: String =
if (rawAnnotation != null) RAW_SERDE else jsonSerdeDecl(ty, qualifiedTypeName)
if (rawAnnotation != null &&
rawAnnotation.contentType != getAnnotationDefaultValue(Raw::class.java, "contentType")) {
serdeDecl = contentTypeDecoratedSerdeDecl(serdeDecl, rawAnnotation.contentType)
Expand All @@ -242,7 +247,7 @@ class KElementConverter(
serdeDecl = contentTypeDecoratedSerdeDecl(serdeDecl, jsonAnnotation.contentType)
}

return PayloadType(false, ty.toString(), boxedType(ty), serdeDecl)
return PayloadType(false, qualifiedTypeName, boxedType(ty, qualifiedTypeName), serdeDecl)
}

private fun contentTypeDecoratedSerdeDecl(serdeDecl: String, contentType: String): String {
Expand Down Expand Up @@ -311,17 +316,40 @@ class KElementConverter(
}
}

private fun jsonSerdeDecl(ty: KSType): String {
private fun jsonSerdeDecl(ty: KSType, qualifiedTypeName: String): String {
return when (ty) {
builtIns.unitType -> "dev.restate.sdk.kotlin.KtSerdes.UNIT"
else -> "dev.restate.sdk.kotlin.KtSerdes.json<${boxedType(ty)}>()"
else -> "dev.restate.sdk.kotlin.KtSerdes.json<${boxedType(ty, qualifiedTypeName)}>()"
}
}

private fun boxedType(ty: KSType): String {
private fun boxedType(ty: KSType, qualifiedTypeName: String): String {
return when (ty) {
builtIns.unitType -> "Unit"
else -> ty.toString()
else -> qualifiedTypeName
}
}

private fun qualifiedTypeName(ksType: KSType): String {
var typeName = ksType.declaration.qualifiedName?.asString() ?: ksType.toString()

if (ksType.arguments.isNotEmpty()) {
typeName =
"$typeName<${
ksType.arguments.joinToString(separator = ", ") {
if (it.variance == Variance.STAR) {
it.variance.label
} else {
"${it.variance.label} ${qualifiedTypeName(it.type!!.resolve())}"
}
}
}>"
}

if (ksType.isMarkedNullable) {
typeName = "$typeName?"
}

return typeName
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import dev.restate.sdk.core.TestDefinitions.TestDefinition
import dev.restate.sdk.core.TestDefinitions.testInvocation
import dev.restate.sdk.core.TestSerdes
import java.util.stream.Stream
import kotlinx.serialization.Serializable

class CodegenTest : TestDefinitions.TestSuite {
@Service
Expand All @@ -42,6 +43,26 @@ class CodegenTest : TestDefinitions.TestSuite {
}
}

@VirtualObject
class NestedDataClass {
@Serializable data class Input(val a: String)

@Serializable data class Output(val a: String)

@Exclusive
suspend fun greet(context: ObjectContext, request: Input): Output {
return Output(request.a)
}

@Exclusive
suspend fun complexType(
context: ObjectContext,
request: Map<Output, List<out Input>>
): Map<Input, List<out Output>> {
return mapOf()
}
}

@VirtualObject
interface GreeterInterface {
@Exclusive suspend fun greet(context: ObjectContext, request: String): String
Expand Down Expand Up @@ -194,6 +215,13 @@ class CodegenTest : TestDefinitions.TestSuite {
.withInput(startMessage(1, "slinkydeveloper"), inputMessage("Francesco"))
.onlyUnbuffered()
.expectingOutput(outputMessage("Francesco"), END_MESSAGE),
testInvocation({ NestedDataClass() }, "greet")
.withInput(
startMessage(1, "slinkydeveloper"),
inputMessage(KtSerdes.json(), NestedDataClass.Input("123")))
.onlyUnbuffered()
.expectingOutput(
outputMessage(KtSerdes.json(), NestedDataClass.Output("123")), END_MESSAGE),
testInvocation({ ObjectGreeterImplementedFromInterface() }, "greet")
.withInput(startMessage(1, "slinkydeveloper"), inputMessage("Francesco"))
.onlyUnbuffered()
Expand Down

0 comments on commit 8f90998

Please sign in to comment.