diff --git a/rewrite-java-test/src/test/java/org/openrewrite/java/JavaTemplateTest.java b/rewrite-java-test/src/test/java/org/openrewrite/java/JavaTemplateTest.java index 4760868cf5c..733385c6cb5 100755 --- a/rewrite-java-test/src/test/java/org/openrewrite/java/JavaTemplateTest.java +++ b/rewrite-java-test/src/test/java/org/openrewrite/java/JavaTemplateTest.java @@ -22,6 +22,7 @@ import org.openrewrite.Issue; import org.openrewrite.java.tree.*; import org.openrewrite.test.RewriteTest; +import org.openrewrite.test.TypeValidation; import java.util.List; @@ -297,7 +298,7 @@ class T { void m() { hashCode(); } - + void m2() { hashCode(); } @@ -720,20 +721,20 @@ public J visitNewClass(J.NewClass newClass, ExecutionContext ctx) { class A { public enum Type { One; - + public Type(String t) { } - + String t; - + public static Type fromType(String type) { return null; } } - + public A(Type type) {} public A() {} - + public void method(Type type) { new A(type); } @@ -743,20 +744,20 @@ public void method(Type type) { class A { public enum Type { One; - + public Type(String t) { } - + String t; - + public static Type fromType(String type) { return null; } } - + public A(Type type) {} public A() {} - + public void method(Type type) { new A(); } @@ -864,7 +865,7 @@ public J visitBinary(J.Binary binary, ExecutionContext ctx) { java( """ import java.util.Collection; - + class Test { void doSomething(Collection c) { assert c.size() > 0; @@ -873,7 +874,7 @@ void doSomething(Collection c) { """, """ import java.util.Collection; - + class Test { void doSomething(Collection c) { assert !c.isEmpty(); @@ -1083,7 +1084,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu """ import java.util.Map; import org.junit.jupiter.api.Assertions; - + class T { void m(String one, Map map) { Assertions.assertEquals(one, map.get("one")); @@ -1092,9 +1093,9 @@ void m(String one, Map map) { """, """ import java.util.Map; - + import static org.assertj.core.api.Assertions.assertThat; - + class T { void m(String one, Map map) { assertThat(map.get("one")).isEqualTo(one); @@ -1139,7 +1140,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu import java.util.Objects; import java.util.Map; import java.util.HashMap; - + class T { void m() { Map map = new HashMap<>(); @@ -1150,10 +1151,10 @@ void m() { """ import java.util.Objects; import java.util.Map; - + import static java.util.Objects.requireNonNull; import java.util.HashMap; - + class T { void m() { Map map = new HashMap<>(); @@ -1181,13 +1182,13 @@ public J visitVariableDeclarations(J.VariableDeclarations multiVariable, Executi java( """ interface Test { - + String a; } """, """ interface Test { - + String a(); } """ @@ -1200,19 +1201,72 @@ void finalMethodParameter() { rewriteRun( spec -> spec.recipe(new ReplaceAnnotation("@org.jetbrains.annotations.NotNull", "@lombok.NonNull", null)), java( - """ - import org.jetbrains.annotations.NotNull; - - class A { - String testMethod(@NotNull final String test) {} - } - """, """ - import lombok.NonNull; - - class A { - String testMethod(@NonNull final String test) {} + """ + import org.jetbrains.annotations.NotNull; + + class A { + String testMethod(@NotNull final String test) {} + } + """, """ + import lombok.NonNull; + + class A { + String testMethod(@NonNull final String test) {} + } + """) + ); + } + + @Test + @Issue("https://github.com/openrewrite/rewrite-spring/pull/284") + void replaceMethodInChainFollowedByGenericTypeParameters() { + rewriteRun( + spec -> spec + .recipe(toRecipe(() -> new JavaVisitor<>() { + @Override + public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { + if (new MethodMatcher("batch.StepBuilder create()").matches(method)) { + return JavaTemplate.builder("new StepBuilder()") + //.doBeforeParseTemplate(System.out::println) + .contextSensitive() + .build() + .apply(getCursor(), method.getCoordinates().replace()); } - """) + return super.visitMethodInvocation(method, ctx); + } + })) + .afterTypeValidationOptions(TypeValidation.builder().constructorInvocations(false).build()) // Unclear why + .parser(JavaParser.fromJavaVersion().dependsOn( + """ + package batch; + public class StepBuilder { + public static StepBuilder create() { return new StepBuilder(); } + public StepBuilder() {} + public T method() { return null; } + } + """ + ) + ), + java( + """ + import batch.StepBuilder; + class Foo { + void test() { + StepBuilder.create() + .method(); + } + } + """, + """ + import batch.StepBuilder; + class Foo { + void test() { + new StepBuilder() + .method(); + } + } + """ + ) ); } } diff --git a/rewrite-java/src/main/java/org/openrewrite/java/TreeVisitingPrinter.java b/rewrite-java/src/main/java/org/openrewrite/java/TreeVisitingPrinter.java index 325cd9f1926..7f086a0522e 100644 --- a/rewrite-java/src/main/java/org/openrewrite/java/TreeVisitingPrinter.java +++ b/rewrite-java/src/main/java/org/openrewrite/java/TreeVisitingPrinter.java @@ -17,14 +17,12 @@ import org.openrewrite.*; import org.openrewrite.internal.lang.Nullable; -import org.openrewrite.java.tree.J; -import org.openrewrite.java.tree.JLeftPadded; -import org.openrewrite.java.tree.JRightPadded; -import org.openrewrite.java.tree.Space; +import org.openrewrite.java.tree.*; import java.util.*; import java.util.stream.Collectors; +import static java.util.stream.Collectors.joining; import static java.util.stream.StreamSupport.stream; @@ -174,7 +172,8 @@ private static String printTreeElement(Tree tree) { return s != null ? s : ""; } - String[] lines = tree.toString().split("\n"); + String precedingComments = tree instanceof J ? printComments(((J) tree).getPrefix().getComments()) : ""; + String[] lines = (precedingComments + tree).split("\n"); StringBuilder output = new StringBuilder(); for (int i = 0; i < lines.length; i++) { output.append(lines[i].trim()); @@ -197,11 +196,17 @@ private static String printSpace(Space space) { sb.append(" whitespace=\"") .append(space.getWhitespace()).append("\""); sb.append(" comments=\"") - .append(String.join(",", space.getComments().stream().map(c -> c.printComment(new Cursor(null, "root"))).collect(Collectors.toList()))) - .append("\"");; + .append(printComments(space.getComments())) + .append("\""); return sb.toString().replace("\n", "\\s\n"); } + private static String printComments(List comments) { + return comments.stream() + .map(c -> c.printComment(new Cursor(null, "root"))) + .collect(joining(",")); + } + @Override public @Nullable Tree visit(@Nullable Tree tree, ExecutionContext ctx) { if (tree == null) { diff --git a/rewrite-java/src/main/java/org/openrewrite/java/internal/template/BlockStatementTemplateGenerator.java b/rewrite-java/src/main/java/org/openrewrite/java/internal/template/BlockStatementTemplateGenerator.java index 1711a302563..902127bf612 100644 --- a/rewrite-java/src/main/java/org/openrewrite/java/internal/template/BlockStatementTemplateGenerator.java +++ b/rewrite-java/src/main/java/org/openrewrite/java/internal/template/BlockStatementTemplateGenerator.java @@ -67,9 +67,17 @@ public String template(Cursor cursor, String template, Space.Location location, after.append('}'); } - template(next(cursor), cursor.getValue(), before, after, cursor.getValue(), mode); + if (contextSensitive) { + contextTemplate(next(cursor), cursor.getValue(), before, after, cursor.getValue(), mode); + } else { + contextFreeTemplate(next(cursor), cursor.getValue(), before, after); + } - return before.toString().trim() + "\n/*" + TEMPLATE_COMMENT + "*/" + template + "/*" + STOP_COMMENT + "*/" + "\n" + after.toString().trim(); + return before.toString().trim() + + "\n/*" + TEMPLATE_COMMENT + "*/" + + template + + "/*" + STOP_COMMENT + "*/\n" + + after.toString().trim(); }); } @@ -84,7 +92,8 @@ public List listTemplatedTrees(JavaSourceFile cu, Class e @Override public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, Integer integer) { - if (getCursor().getParentTreeCursor().getValue() instanceof SourceFile && (classDecl.getSimpleName().equals("__P__") || classDecl.getSimpleName().equals("__M__"))) { + if (getCursor().getParentTreeCursor().getValue() instanceof SourceFile && + (classDecl.getSimpleName().equals("__P__") || classDecl.getSimpleName().equals("__M__"))) { // don't visit the __P__ and __M__ classes declaring stubs return classDecl; } @@ -193,14 +202,6 @@ private boolean isTemplateStopComment(Comment comment) { return js; } - private void template(Cursor cursor, J prior, StringBuilder before, StringBuilder after, J insertionPoint, JavaCoordinates.Mode mode) { - if (contextSensitive) { - contextTemplate(cursor, prior, before, after, insertionPoint, mode); - } else { - contextFreeTemplate(cursor, prior, before, after); - } - } - @SuppressWarnings("DataFlowIssue") protected void contextFreeTemplate(Cursor cursor, J j, StringBuilder before, StringBuilder after) { if (j instanceof J.Lambda) { @@ -774,11 +775,13 @@ private static class RemoveTreeMarker implements Marker { private static class TemplatedTreeTrimmerVisitor extends JavaVisitor { private boolean stopCommentExists(@Nullable J j) { - if (j != null) { - for (Comment comment : j.getComments()) { - if (comment instanceof TextComment && ((TextComment) comment).getText().equals(STOP_COMMENT)) { - return true; - } + return j != null && stopCommentExists(j.getComments()); + } + + private static boolean stopCommentExists(List comments) { + for (Comment comment : comments) { + if (comment instanceof TextComment && ((TextComment) comment).getText().equals(STOP_COMMENT)) { + return true; } } return false; @@ -803,6 +806,13 @@ public J visitMethodInvocation(J.MethodInvocation method, Integer integer) { //noinspection ConstantConditions return mi.getSelect(); } + if (method.getTypeParameters() != null) { + // For method chains return `select` if `STOP_COMMENT` is found before `typeParameters` + if (stopCommentExists(mi.getPadding().getTypeParameters().getBefore().getComments())) { + //noinspection ConstantConditions + return mi.getSelect(); + } + } return mi; } }