diff --git a/sdk-api-kotlin-gen/build.gradle.kts b/sdk-api-kotlin-gen/build.gradle.kts index 94b3f3a3..feb0bf08 100644 --- a/sdk-api-kotlin-gen/build.gradle.kts +++ b/sdk-api-kotlin-gen/build.gradle.kts @@ -1,6 +1,7 @@ plugins { java kotlin("jvm") + kotlin("plugin.serialization") `library-publishing-conventions` alias(kotlinLibs.plugins.ksp) } @@ -22,6 +23,7 @@ dependencies { testImplementation(coreLibs.protobuf.java) testImplementation(coreLibs.log4j.core) testImplementation(kotlinLibs.kotlinx.coroutines) + testImplementation(kotlinLibs.kotlinx.serialization.core) // Import test suites from sdk-core testImplementation(project(":sdk-core", "testArchive")) diff --git a/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/KElementConverter.kt b/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/KElementConverter.kt index 411c3f0b..22b18688 100644 --- a/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/KElementConverter.kt +++ b/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/KElementConverter.kt @@ -231,8 +231,13 @@ class KElementConverter( if (rawAnnotation != null && ty != byteArrayType) { logger.error("A parameter annotated with @Raw MUST be of type byte[], was $ty", relatedNode) } + if (ty.isFunctionType || ty.isSuspendFunctionType) { + logger.error("Cannot use fun as parameter or return type", relatedNode) + } - var serdeDecl: String = if (rawAnnotation != null) RAW_SERDE else jsonSerdeDecl(ty) + val qualifiedTypeName = qualifiedTypeName(ty) + var serdeDecl: String = + if (rawAnnotation != null) RAW_SERDE else jsonSerdeDecl(ty, qualifiedTypeName) if (rawAnnotation != null && rawAnnotation.contentType != getAnnotationDefaultValue(Raw::class.java, "contentType")) { serdeDecl = contentTypeDecoratedSerdeDecl(serdeDecl, rawAnnotation.contentType) @@ -242,7 +247,7 @@ class KElementConverter( serdeDecl = contentTypeDecoratedSerdeDecl(serdeDecl, jsonAnnotation.contentType) } - return PayloadType(false, ty.toString(), boxedType(ty), serdeDecl) + return PayloadType(false, qualifiedTypeName, boxedType(ty, qualifiedTypeName), serdeDecl) } private fun contentTypeDecoratedSerdeDecl(serdeDecl: String, contentType: String): String { @@ -311,17 +316,40 @@ class KElementConverter( } } - private fun jsonSerdeDecl(ty: KSType): String { + private fun jsonSerdeDecl(ty: KSType, qualifiedTypeName: String): String { return when (ty) { builtIns.unitType -> "dev.restate.sdk.kotlin.KtSerdes.UNIT" - else -> "dev.restate.sdk.kotlin.KtSerdes.json<${boxedType(ty)}>()" + else -> "dev.restate.sdk.kotlin.KtSerdes.json<${boxedType(ty, qualifiedTypeName)}>()" } } - private fun boxedType(ty: KSType): String { + private fun boxedType(ty: KSType, qualifiedTypeName: String): String { return when (ty) { builtIns.unitType -> "Unit" - else -> ty.toString() + else -> qualifiedTypeName + } + } + + private fun qualifiedTypeName(ksType: KSType): String { + var typeName = ksType.declaration.qualifiedName?.asString() ?: ksType.toString() + + if (ksType.arguments.isNotEmpty()) { + typeName = + "$typeName<${ + ksType.arguments.joinToString(separator = ", ") { + if (it.variance == Variance.STAR) { + it.variance.label + } else { + "${it.variance.label} ${qualifiedTypeName(it.type!!.resolve())}" + } + } + }>" } + + if (ksType.isMarkedNullable) { + typeName = "$typeName?" + } + + return typeName } } diff --git a/sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/CodegenTest.kt b/sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/CodegenTest.kt index 4ad5dc89..82dea76a 100644 --- a/sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/CodegenTest.kt +++ b/sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/CodegenTest.kt @@ -18,6 +18,7 @@ import dev.restate.sdk.core.TestDefinitions.TestDefinition import dev.restate.sdk.core.TestDefinitions.testInvocation import dev.restate.sdk.core.TestSerdes import java.util.stream.Stream +import kotlinx.serialization.Serializable class CodegenTest : TestDefinitions.TestSuite { @Service @@ -42,6 +43,26 @@ class CodegenTest : TestDefinitions.TestSuite { } } + @VirtualObject + class NestedDataClass { + @Serializable data class Input(val a: String) + + @Serializable data class Output(val a: String) + + @Exclusive + suspend fun greet(context: ObjectContext, request: Input): Output { + return Output(request.a) + } + + @Exclusive + suspend fun complexType( + context: ObjectContext, + request: Map> + ): Map> { + return mapOf() + } + } + @VirtualObject interface GreeterInterface { @Exclusive suspend fun greet(context: ObjectContext, request: String): String @@ -194,6 +215,13 @@ class CodegenTest : TestDefinitions.TestSuite { .withInput(startMessage(1, "slinkydeveloper"), inputMessage("Francesco")) .onlyUnbuffered() .expectingOutput(outputMessage("Francesco"), END_MESSAGE), + testInvocation({ NestedDataClass() }, "greet") + .withInput( + startMessage(1, "slinkydeveloper"), + inputMessage(KtSerdes.json(), NestedDataClass.Input("123"))) + .onlyUnbuffered() + .expectingOutput( + outputMessage(KtSerdes.json(), NestedDataClass.Output("123")), END_MESSAGE), testInvocation({ ObjectGreeterImplementedFromInterface() }, "greet") .withInput(startMessage(1, "slinkydeveloper"), inputMessage("Francesco")) .onlyUnbuffered()