From 3412efbd62dd4d155dd633b1674e54a1cdd023db Mon Sep 17 00:00:00 2001
From: SMIT MALKAN <smitmalkan99@gmail.com>
Date: Sun, 18 Aug 2024 01:32:27 +0530
Subject: [PATCH] Collapse consecutive `assertThat` statements (#373) (#392)

* added recipe to collapse consecutive AssertThat statements (#373)

* Apply formatter and remove unused builder

* fixed formatting (#373)

* Fix test indentation and type issues

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/main/java/org/openrewrite/java/testing/assertj/CollapseConsecutiveAssertThatStatements.java

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Apply suggestions from code review

* Strip out unnecessary elements from tests

* Remove test with compilation error

* First round of conventions applied

* Apply quick old suggestion

* Showcase an issue with incorrect use of `extracting`

* Alternative implementation without index or nested visitors just yet

* Compare types for the last unit test to pass

* Further simplification

* Remove need for autoformat

* Do not duplicate indent, but guess continuation indent

* Make collapse of assertThat part of best practices

* Only retain last newline

---------

Co-authored-by: Tim te Beek <tim@moderne.io>
Co-authored-by: Tim te Beek <timtebeek@gmail.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 ...llapseConsecutiveAssertThatStatements.java | 130 +++++++
 .../resources/META-INF/rewrite/assertj.yml    |   2 +
 ...seConsecutiveAssertThatStatementsTest.java | 320 ++++++++++++++++++
 3 files changed, 452 insertions(+)
 create mode 100644 src/main/java/org/openrewrite/java/testing/assertj/CollapseConsecutiveAssertThatStatements.java
 create mode 100644 src/test/java/org/openrewrite/java/testing/assertj/CollapseConsecutiveAssertThatStatementsTest.java

diff --git a/src/main/java/org/openrewrite/java/testing/assertj/CollapseConsecutiveAssertThatStatements.java b/src/main/java/org/openrewrite/java/testing/assertj/CollapseConsecutiveAssertThatStatements.java
new file mode 100644
index 000000000..478474ef8
--- /dev/null
+++ b/src/main/java/org/openrewrite/java/testing/assertj/CollapseConsecutiveAssertThatStatements.java
@@ -0,0 +1,130 @@
+/*
+ * Copyright 2024 the original author or authors.
+ * <p>
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * <p>
+ * https://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.openrewrite.java.testing.assertj;
+
+import org.openrewrite.*;
+import org.openrewrite.java.JavaIsoVisitor;
+import org.openrewrite.java.MethodMatcher;
+import org.openrewrite.java.search.SemanticallyEqual;
+import org.openrewrite.java.search.UsesMethod;
+import org.openrewrite.java.tree.*;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+@Incubating(since = "2.17.0")
+public class CollapseConsecutiveAssertThatStatements extends Recipe {
+    private static final MethodMatcher ASSERT_THAT = new MethodMatcher("org.assertj.core.api.Assertions assertThat(..)");
+
+    @Override
+    public String getDisplayName() {
+        return "Collapse consecutive `assertThat` statements";
+    }
+
+    @Override
+    public String getDescription() {
+        return "Collapse consecutive `assertThat` statements into single `assertThat` chained statement. This recipe ignores `assertThat` statements that have method invocation as parameter.";
+    }
+
+    @Override
+    public TreeVisitor<?, ExecutionContext> getVisitor() {
+        return Preconditions.check(new UsesMethod<>(ASSERT_THAT), new JavaIsoVisitor<ExecutionContext>() {
+            @Override
+            public J.Block visitBlock(J.Block block, ExecutionContext ctx) {
+                J.Block bl = super.visitBlock(block, ctx);
+
+                List<Statement> statementsCollapsed = new ArrayList<>();
+                for (List<Statement> group : getGroupedStatements(bl)) {
+                    if (group.size() <= 1) {
+                        statementsCollapsed.addAll(group);
+                    } else {
+                        statementsCollapsed.add(getCollapsedAssertThat(group));
+                    }
+                }
+
+                return bl.withStatements(statementsCollapsed);
+            }
+
+            private List<List<Statement>> getGroupedStatements(J.Block bl) {
+                List<Statement> originalStatements = bl.getStatements();
+                List<List<Statement>> groupedStatements = new ArrayList<>();
+                Expression currentActual = null; // The actual argument of the current group of assertThat statements
+                List<Statement> currentGroup = new ArrayList<>();
+                for (Statement statement : originalStatements) {
+                    if (statement instanceof J.MethodInvocation) {
+                        J.MethodInvocation assertion = (J.MethodInvocation) statement;
+                        if (isGroupableAssertion(assertion)) {
+                            J.MethodInvocation assertThat = (J.MethodInvocation) assertion.getSelect();
+                            assert assertThat != null;
+                            Expression actual = assertThat.getArguments().get(0);
+                            if (currentActual == null || !SemanticallyEqual.areEqual(currentActual, actual)) {
+                                // Conclude the previous group
+                                groupedStatements.add(currentGroup);
+                                currentGroup = new ArrayList<>();
+                                currentActual = actual;
+                            }
+                            currentGroup.add(statement);
+                            continue;
+                        }
+                    }
+
+                    // Conclude the previous group, and start a new group
+                    groupedStatements.add(currentGroup);
+                    currentGroup = new ArrayList<>();
+                    currentActual = null;
+                    // The current statement should not be grouped with any other statement
+                    groupedStatements.add(Collections.singletonList(statement));
+                }
+                if (!currentGroup.isEmpty()) {
+                    // Conclude the last group
+                    groupedStatements.add(currentGroup);
+                }
+                return groupedStatements;
+            }
+
+            private boolean isGroupableAssertion(J.MethodInvocation assertion) {
+                // Only match method invocations where the select is an assertThat, containing a non-method call argument
+                if (ASSERT_THAT.matches(assertion.getSelect())) {
+                    J.MethodInvocation assertThat = (J.MethodInvocation) assertion.getSelect();
+                    if (assertThat != null && !(assertThat.getArguments().get(0) instanceof MethodCall)) {
+                        return TypeUtils.isOfType(assertThat.getType(), assertion.getType());
+                    }
+                }
+                return false;
+            }
+
+            private J.MethodInvocation getCollapsedAssertThat(List<Statement> consecutiveAssertThatStatement) {
+                assert !consecutiveAssertThatStatement.isEmpty();
+                Space originalPrefix = consecutiveAssertThatStatement.get(0).getPrefix();
+                String continuationIndent = originalPrefix.getIndent().contains("\t") ? "\t\t" : "        ";
+                Space indentedNewline = Space.format(originalPrefix.getLastWhitespace().replaceAll("^\\s+\n", "\n") +
+                                                     continuationIndent);
+                J.MethodInvocation collapsed = null;
+                for (Statement st : consecutiveAssertThatStatement) {
+                    J.MethodInvocation assertion = (J.MethodInvocation) st;
+                    J.MethodInvocation assertThat = (J.MethodInvocation) assertion.getSelect();
+                    assert assertThat != null;
+                    J.MethodInvocation newSelect = collapsed == null ? assertThat : collapsed;
+                    collapsed = assertion.getPadding().withSelect(JRightPadded
+                            .build((Expression) newSelect.withPrefix(Space.EMPTY))
+                            .withAfter(indentedNewline));
+                }
+                return collapsed.withPrefix(originalPrefix);
+            }
+        });
+    }
+}
diff --git a/src/main/resources/META-INF/rewrite/assertj.yml b/src/main/resources/META-INF/rewrite/assertj.yml
index 7dcfc91b9..28b7eae0b 100644
--- a/src/main/resources/META-INF/rewrite/assertj.yml
+++ b/src/main/resources/META-INF/rewrite/assertj.yml
@@ -44,6 +44,8 @@ recipeList:
   - tech.picnic.errorprone.refasterrules.AssertJStringRulesRecipes
   - tech.picnic.errorprone.refasterrules.AssertJThrowingCallableRulesRecipes
 
+  - org.openrewrite.java.testing.assertj.CollapseConsecutiveAssertThatStatements
+
 ---
 type: specs.openrewrite.org/v1beta/recipe
 name: org.openrewrite.java.testing.assertj.StaticImports
diff --git a/src/test/java/org/openrewrite/java/testing/assertj/CollapseConsecutiveAssertThatStatementsTest.java b/src/test/java/org/openrewrite/java/testing/assertj/CollapseConsecutiveAssertThatStatementsTest.java
new file mode 100644
index 000000000..fdf25e35a
--- /dev/null
+++ b/src/test/java/org/openrewrite/java/testing/assertj/CollapseConsecutiveAssertThatStatementsTest.java
@@ -0,0 +1,320 @@
+/*
+ * Copyright 2024 the original author or authors.
+ * <p>
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * <p>
+ * https://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.openrewrite.java.testing.assertj;
+
+import org.junit.jupiter.api.Test;
+import org.openrewrite.DocumentExample;
+import org.openrewrite.InMemoryExecutionContext;
+import org.openrewrite.java.JavaParser;
+import org.openrewrite.test.RecipeSpec;
+import org.openrewrite.test.RewriteTest;
+
+import static org.openrewrite.java.Assertions.java;
+
+class CollapseConsecutiveAssertThatStatementsTest implements RewriteTest {
+
+    @Override
+    public void defaults(RecipeSpec spec) {
+        spec
+          .parser(JavaParser.fromJavaVersion()
+            .classpathFromResources(new InMemoryExecutionContext(), "assertj-core-3.24"))
+          .recipe(new CollapseConsecutiveAssertThatStatements());
+    }
+
+    @DocumentExample
+    @Test
+    void collapseIfConsecutiveAssertThatPresent() {
+        //language=java
+        rewriteRun(
+          java(
+            """
+              import java.util.Arrays;
+              import java.util.List;
+              import static org.assertj.core.api.Assertions.assertThat;
+
+              class MyTest {
+                  void test() {
+                      List<String> listA = Arrays.asList("a", "b", "c");
+                      assertThat(listA).isNotNull();
+                      assertThat(listA).hasSize(3);
+                      assertThat(listA).containsExactly("a", "b", "c");
+                  }
+                  private int[] notification() {
+                      return new int[]{1, 2, 3};
+                  }
+              }
+              """,
+            """
+              import java.util.Arrays;
+              import java.util.List;
+              import static org.assertj.core.api.Assertions.assertThat;
+
+              class MyTest {
+                  void test() {
+                      List<String> listA = Arrays.asList("a", "b", "c");
+                      assertThat(listA)
+                              .isNotNull()
+                              .hasSize(3)
+                              .containsExactly("a", "b", "c");
+                  }
+                  private int[] notification() {
+                      return new int[]{1, 2, 3};
+                  }
+              }
+              """
+          )
+        );
+    }
+
+    @Test
+    void collapseIfMultipleConsecutiveAssertThatPresent() {
+        //language=java
+        rewriteRun(
+          java(
+            """
+              import java.util.Arrays;
+              import java.util.List;
+              import static org.assertj.core.api.Assertions.assertThat;
+
+              class MyTest {
+                  void test() {
+                      List<String> listA = Arrays.asList("a", "b", "c");
+                      // Comment nor whitespace below duplicated
+                      assertThat(listA).isNotNull();
+                      assertThat(listA).hasSize(3);
+                      assertThat(listA).containsExactly("a", "b", "c");
+
+                      List<String> listB = Arrays.asList("a", "b", "c");
+
+                      assertThat(listB).isNotNull();
+                      assertThat(listB).hasSize(3);
+                  }
+
+                  private int[] notification() {
+                      return new int[]{1, 2, 3};
+                  }
+              }
+              """,
+            """
+              import java.util.Arrays;
+              import java.util.List;
+              import static org.assertj.core.api.Assertions.assertThat;
+
+              class MyTest {
+                  void test() {
+                      List<String> listA = Arrays.asList("a", "b", "c");
+                      // Comment nor whitespace below duplicated
+                      assertThat(listA)
+                              .isNotNull()
+                              .hasSize(3)
+                              .containsExactly("a", "b", "c");
+
+                      List<String> listB = Arrays.asList("a", "b", "c");
+
+                      assertThat(listB)
+                              .isNotNull()
+                              .hasSize(3);
+                  }
+
+                  private int[] notification() {
+                      return new int[]{1, 2, 3};
+                  }
+              }
+              """
+          )
+        );
+    }
+
+    @Test
+    void collapseIfMultipleConsecutiveAssertThatPresent2() {
+        //language=java
+        rewriteRun(
+          java(
+            """
+              import java.util.Arrays;
+              import java.util.List;
+              import static org.assertj.core.api.Assertions.assertThat;
+
+              class MyTest2 {
+                  void test() {
+                      List<String> listA = Arrays.asList("a", "b", "c");
+                      assertThat(listA).isNotNull();
+                      assertThat(listA).hasSize(3);
+                      List<String> listB = Arrays.asList("a", "b", "c");
+                      assertThat(listA).containsExactly("a", "b", "c");
+                      assertThat(listB).isNotNull();
+                      assertThat(listB).hasSize(3);
+                  }
+
+                  private int[] notification() {
+                      return new int[]{1, 2, 3};
+                  }
+              }
+              """,
+            """
+              import java.util.Arrays;
+              import java.util.List;
+              import static org.assertj.core.api.Assertions.assertThat;
+
+              class MyTest2 {
+                  void test() {
+                      List<String> listA = Arrays.asList("a", "b", "c");
+                      assertThat(listA)
+                              .isNotNull()
+                              .hasSize(3);
+                      List<String> listB = Arrays.asList("a", "b", "c");
+                      assertThat(listA).containsExactly("a", "b", "c");
+                      assertThat(listB)
+                              .isNotNull()
+                              .hasSize(3);
+                  }
+
+                  private int[] notification() {
+                      return new int[]{1, 2, 3};
+                  }
+              }
+              """
+          )
+        );
+    }
+
+    @Test
+    void ignoreIfAssertThatOnDifferentVariables() {
+        //language=java
+        rewriteRun(
+          java(
+            """
+              import java.util.Arrays;
+              import java.util.List;
+              import static org.assertj.core.api.Assertions.assertThat;
+
+              class MyTest {
+                  void test() {
+                      List<String> listA = Arrays.asList("a", "b", "c");
+                      List<String> listB = Arrays.asList("a", "b", "c");
+                      assertThat(listA).isNotNull();
+                      assertThat(listB).containsExactly("a", "b", "c");
+                  }
+
+                  private int[] notification() {
+                      return new int[]{1, 2, 3};
+                  }
+              }
+              """
+          )
+        );
+    }
+
+    @Test
+    void ignoreIfAssertThatOnMethodInvocation() {
+        //language=java
+        rewriteRun(
+          java(
+            """
+              import java.util.Arrays;
+              import java.util.List;
+              import static org.assertj.core.api.Assertions.assertThat;
+
+              class MyTest {
+                  void test() {
+                      assertThat(notification()).isNotNull();
+                      assertThat(notification()).isTrue();
+                  }
+                  private boolean notification() {
+                      return true;
+                  }
+              }
+              """
+          )
+        );
+    }
+
+    @Test
+    void ignoreIfAssertThatChainExists() {
+        //language=java
+        rewriteRun(
+          java(
+            """
+              import java.util.Arrays;
+              import java.util.List;
+              import static org.assertj.core.api.Assertions.assertThat;
+
+              class MyTest {
+                  void test() {
+                      List<String> listA = Arrays.asList("a", "b", "c");
+                      assertThat(listA).containsExactly("a", "b", "c");
+                      assertThat(listA)
+                          .isNotNull()
+                          .hasSize(3);
+                      assertThat(listA).containsExactly("a", "b", "c");
+                  }
+                  private int[] notification() {
+                      return new int[]{1, 2, 3};
+                  }
+              }
+              """
+          )
+        );
+    }
+
+    @Test
+    void ignoreIfStatementPresentBetweenTwoAssertThat() {
+        //language=java
+        rewriteRun(
+          java(
+            """
+              import java.util.Arrays;
+              import java.util.List;
+              import static org.assertj.core.api.Assertions.assertThat;
+
+              class MyTest {
+                  void test() {
+                      List<String> listA = Arrays.asList("a", "b", "c");
+                      assertThat(listA).isNotNull();
+                      int x=3;
+                      assertThat(listA).hasSize(x);
+                  }
+                  private int[] notification() {
+                      return new int[]{1, 2, 3};
+                  }
+              }
+              """
+          )
+        );
+    }
+
+    @Test
+    void ignoreIncorrectUseOfExtracting() {
+        //language=java
+        rewriteRun(
+          java(
+            """
+              import static org.assertj.core.api.Assertions.assertThat;
+
+              class Node { Node parent; Node getParent() { return parent; } }
+
+              class MyTest {
+                  // Should not collapse these two, even if `extracting` is used incorrectly
+                  void b(Node node) {
+                      assertThat(node).extracting(Node::getParent);
+                      assertThat(node).isNotNull();
+                  }
+              }
+              """
+          )
+        );
+    }
+}