diff --git a/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/desugar/AnnotationDesugar.java b/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/desugar/AnnotationDesugar.java index 79b4af59f064..02d4802ebce4 100644 --- a/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/desugar/AnnotationDesugar.java +++ b/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/desugar/AnnotationDesugar.java @@ -32,6 +32,7 @@ import org.ballerinalang.model.tree.TopLevelNode; import org.ballerinalang.model.tree.expressions.RecordLiteralNode; import org.ballerinalang.model.types.TypeKind; +import org.wso2.ballerinalang.compiler.parser.NodeCloner; import org.wso2.ballerinalang.compiler.semantics.analyzer.ConstantValueResolver; import org.wso2.ballerinalang.compiler.semantics.analyzer.SymbolResolver; import org.wso2.ballerinalang.compiler.semantics.analyzer.Types; @@ -140,6 +141,8 @@ public class AnnotationDesugar { private final ConstantValueResolver constantValueResolver; private final ClosureGenerator closureGenerator; + private final NodeCloner nodeCloner; + public static AnnotationDesugar getInstance(CompilerContext context) { AnnotationDesugar annotationDesugar = context.get(ANNOTATION_DESUGAR_KEY); if (annotationDesugar == null) { @@ -157,6 +160,7 @@ private AnnotationDesugar(CompilerContext context) { this.symResolver = SymbolResolver.getInstance(context); this.constantValueResolver = ConstantValueResolver.getInstance(context); this.closureGenerator = ClosureGenerator.getInstance(context); + this.nodeCloner = NodeCloner.getInstance(context); } /** @@ -483,6 +487,22 @@ private void defineTypeAnnotations(BLangPackage pkgNode, SymbolEnv env, BLangFun private void defineFunctionAnnotations(BLangPackage pkgNode, SymbolEnv env, BLangFunction initFunction) { BLangBlockFunctionBody initFnBody = (BLangBlockFunctionBody) initFunction.body; BLangFunction[] functions = pkgNode.functions.toArray(new BLangFunction[pkgNode.functions.size()]); + // TODO: Find a better way to get the annotation map init statement + Optional globalAnnotInitStmt = initFnBody.stmts.stream() + .filter(statement -> statement.getKind() == NodeKind.ASSIGNMENT && ((BLangAssignment) statement) + .varRef.getKind() == NodeKind.SIMPLE_VARIABLE_REF) + .filter(statement -> ((BLangSimpleVarRef) ((BLangAssignment) statement).varRef).symbol.name.value. + equals(ANNOTATION_DATA)).findFirst(); + + if (globalAnnotInitStmt.isPresent()) { + BLangAssignment stmt = (BLangAssignment) globalAnnotInitStmt.get(); + int index = initFnBody.stmts.indexOf(stmt); + BLangAssignment clonedStmt = nodeCloner.cloneNode(stmt); + ((BLangSimpleVarRef) clonedStmt.varRef).symbol = ((BLangSimpleVarRef) stmt.varRef).symbol; + initFnBody.stmts.add(index + 1, clonedStmt); + globalAnnotInitStmt = Optional.of(clonedStmt); + } + for (BLangFunction function : functions) { PackageID pkgID = function.symbol.pkgID; BSymbol owner = function.symbol.owner; @@ -505,9 +525,13 @@ private void defineFunctionAnnotations(BLangPackage pkgNode, SymbolEnv env, BLan if (function.attachedFunction && Symbols.isFlagOn(function.receiver.getBType().flags, Flags.OBJECT_CTOR)) { addLambdaToGlobalAnnotMap(identifier, lambdaFunction, target); + globalAnnotInitStmt.ifPresent(statement -> updateInitOfAnnotGlobalMapWithLambda(statement, + identifier, lambdaFunction)); index = calculateIndex(initFnBody.stmts, function.receiver.getBType().tsymbol); } else { addInvocationToGlobalAnnotMap(identifier, lambdaFunction, target); + globalAnnotInitStmt.ifPresent(statement -> updateInitOfAnnotGlobalMapWithInvocation(statement, + identifier, lambdaFunction)); index = initFnBody.stmts.size(); } @@ -519,6 +543,20 @@ private void defineFunctionAnnotations(BLangPackage pkgNode, SymbolEnv env, BLan } } + private void updateInitOfAnnotGlobalMapWithInvocation(BLangStatement globalAnnotInitStmt, String identifier, + BLangLambdaFunction expression) { + BLangAssignment annotMapInitStmt = (BLangAssignment) globalAnnotInitStmt; + BLangRecordLiteral mapLiteral = (BLangRecordLiteral) annotMapInitStmt.expr; + addInvocationToLiteral(mapLiteral, identifier, expression.pos, expression); + } + + private void updateInitOfAnnotGlobalMapWithLambda(BLangStatement globalAnnotInitStmt, String identifier, + BLangLambdaFunction expression) { + BLangAssignment annotMapInitStmt = (BLangAssignment) globalAnnotInitStmt; + BLangRecordLiteral mapLiteral = (BLangRecordLiteral) annotMapInitStmt.expr; + addLambdaToLiteral(mapLiteral, identifier, expression.pos, expression); + } + private void attachSchedulerPolicy(BLangFunction function) { for (BLangAnnotationAttachment annotation : function.annAttachments) { if (!annotation.annotationName.value.equals("strand")) { @@ -986,6 +1024,13 @@ private void addInvocationToLiteral(BLangRecordLiteral recordLiteral, String ide ASTBuilderUtil.createLiteral(pos, symTable.stringType, identifier), annotFuncInvocation)); } + private void addLambdaToLiteral(BLangRecordLiteral recordLiteral, String identifier, + Location pos, BLangLambdaFunction lambdaFunction) { + recordLiteral.fields.add(ASTBuilderUtil.createBLangRecordKeyValue( + ASTBuilderUtil.createLiteral(pos, symTable.stringType, identifier), + ASTBuilderUtil.createVariableRef(lambdaFunction.pos, lambdaFunction.function.symbol))); + } + private void addInvocationToLiteral(BLangRecordLiteral recordLiteral, String identifier, Location pos, BInvokableSymbol invokableSymbol) { BLangInvocation annotFuncInvocation = getInvocation(invokableSymbol);