From b19c4ff870cce3300eea24dc9ebe96e2167977cb Mon Sep 17 00:00:00 2001
From: Ritchie1108 <ritchie1108@gmail.com>
Date: Wed, 3 Apr 2024 16:15:09 -0400
Subject: [PATCH] refactor: utilize incremental search in AccessibleChildren
 calls (#50)

---
 src/IAccessibleUtils.h | 147 +++++++++++++++++++----------------------
 1 file changed, 67 insertions(+), 80 deletions(-)

diff --git a/src/IAccessibleUtils.h b/src/IAccessibleUtils.h
index 34e29be..8108cf9 100644
--- a/src/IAccessibleUtils.h
+++ b/src/IAccessibleUtils.h
@@ -5,7 +5,6 @@
 #pragma comment(lib, "oleacc.lib")
 
 #include <wrl/client.h>
-#include <queue>
 #include <thread>
 
 using NodePtr = Microsoft::WRL::ComPtr<IAccessible>;
@@ -92,8 +91,7 @@ long GetAccessibleState(NodePtr node) {
 }
 
 template <typename Function>
-void TraversalAccessible(NodePtr node, Function f, bool rawTraversal = false,
-                         bool useBFS = false) {
+void TraversalAccessible(NodePtr node, Function f) {
   if (!node)
     return;
 
@@ -101,73 +99,65 @@ void TraversalAccessible(NodePtr node, Function f, bool rawTraversal = false,
   if (S_OK != node->get_accChildCount(&childCount) || !childCount)
     return;
 
-  std::vector<VARIANT> varChildren(childCount);
-  if (S_OK != AccessibleChildren(node.Get(), 0, childCount, varChildren.data(),
-                                 &childCount))
-    return;
-
-  if (useBFS) {
-    BFSTraversal(node, f, rawTraversal, varChildren);
-  } else {
-    DFSTraversal(node, f, rawTraversal, varChildren);
-  }
-}
-
-void BFSTraversal(NodePtr node, std::function<bool(NodePtr)> f,
-                  bool rawTraversal, std::vector<VARIANT> varChildren) {
-  std::queue<NodePtr> queue;
-  for (const auto& varChild : varChildren) {
-    if (varChild.vt != VT_DISPATCH)
-      continue;
-
-    Microsoft::WRL::ComPtr<IDispatch> dispatch = varChild.pdispVal;
-    NodePtr child = nullptr;
-    if (S_OK != dispatch->QueryInterface(IID_IAccessible, (void**)&child))
-      continue;
-
-    queue.push(child);
-  }
+  auto nStep = childCount < 20 ? childCount : 20;
+  for (auto i = 0; i < childCount;) {
+    auto arrChildren = std::make_unique<VARIANT[]>(nStep);
 
-  while (!queue.empty()) {
-    NodePtr current = queue.front();
-    queue.pop();
+    long nGetCount = 0;
+    if (S_OK == AccessibleChildren(node.Get(), i, nStep, arrChildren.get(),
+                                   &nGetCount)) {
 
-    if (rawTraversal) {
-      if (f(current))
-        break;
+      bool bDone = false;
+      for (int j = 0; j < nGetCount; ++j) {
+        if (arrChildren[j].vt != VT_DISPATCH) {
+          continue;
+        }
+        if (bDone) {
+          arrChildren[j].pdispVal->Release(); // 立刻释放,避免内存泄漏
+          continue;
+        }
 
-      long childCount = 0;
-      if (S_OK == current->get_accChildCount(&childCount) && childCount > 0) {
-        std::vector<VARIANT> varChildren(childCount);
-        if (S_OK == AccessibleChildren(current.Get(), 0, childCount,
-                                       varChildren.data(), &childCount)) {
-          for (const auto& varChild : varChildren) {
-            if (varChild.vt != VT_DISPATCH)
-              continue;
-
-            Microsoft::WRL::ComPtr<IDispatch> dispatch = varChild.pdispVal;
-            NodePtr child = nullptr;
-            if (S_OK !=
-                dispatch->QueryInterface(IID_IAccessible, (void**)&child))
-              continue;
-
-            queue.push(child);
+        Microsoft::WRL::ComPtr<IDispatch> pDispatch = arrChildren[j].pdispVal;
+        NodePtr pChild = nullptr;
+        if (S_OK ==
+            pDispatch->QueryInterface(IID_IAccessible, (void**)&pChild)) {
+          if ((GetAccessibleState(pChild) & STATE_SYSTEM_INVISIBLE) == 0) {
+            if (f(pChild)) {
+              bDone = true;
+            }
           }
         }
       }
-    } else {
-      if ((GetAccessibleState(current) & STATE_SYSTEM_INVISIBLE) ==
-          0)  // 只遍历可见节点
-      {
-        if (f(current))
-          break;
+
+      if (bDone) {
+        return;
       }
     }
+
+    i += nStep;
+
+    if (i + nStep >= childCount) {
+      nStep = childCount - i;
+    }
   }
 }
 
-void DFSTraversal(NodePtr node, std::function<bool(NodePtr)> f,
-                  bool rawTraversal, std::vector<VARIANT> varChildren) {
+// 原 TraversalAccessible 函数,现在被使用步进的新函数替代,只保留 Raw 遍历功能
+template <typename Function>
+void TraversalRawAccessible(NodePtr node, Function f,
+                            bool rawTraversal = false) {
+  if (!node)
+    return;
+
+  long childCount = 0;
+  if (S_OK != node->get_accChildCount(&childCount) || !childCount)
+    return;
+
+  std::vector<VARIANT> varChildren(childCount);
+  if (S_OK != AccessibleChildren(node.Get(), 0, childCount, varChildren.data(),
+                                 &childCount))
+    return;
+
   for (const auto& varChild : varChildren) {
     if (varChild.vt != VT_DISPATCH)
       continue;
@@ -178,7 +168,7 @@ void DFSTraversal(NodePtr node, std::function<bool(NodePtr)> f,
       continue;
 
     if (rawTraversal) {
-      TraversalAccessible(child, f, true);
+      TraversalRawAccessible(child, f, true);
       if (f(child))
         break;
     } else {
@@ -482,7 +472,7 @@ bool IsOnBookmark(NodePtr top, POINT pt) {
   }
 
   bool flag = false;
-  TraversalAccessible(
+  TraversalRawAccessible(
       top,
       [&flag, &pt](NodePtr child) {
         if (GetAccessibleRole(child) != ROLE_SYSTEM_PUSHBUTTON) {
@@ -526,28 +516,25 @@ bool IsOnMenuBookmark(NodePtr top, POINT pt) {
   }
 
   bool flag = false;
-  TraversalAccessible(
-      MenuBarPane,
-      [&flag, &pt](NodePtr child) {
-        if (GetAccessibleRole(child) != ROLE_SYSTEM_MENUITEM) {
-          return false;
-        }
+  TraversalAccessible(MenuBarPane, [&flag, &pt](NodePtr child) {
+    if (GetAccessibleRole(child) != ROLE_SYSTEM_MENUITEM) {
+      return false;
+    }
 
-        GetAccessibleSize(child, [&flag, &pt, &child](RECT rect) {
-          if (!PtInRect(&rect, pt)) {
-            return;
-          }
+    GetAccessibleSize(child, [&flag, &pt, &child](RECT rect) {
+      if (!PtInRect(&rect, pt)) {
+        return;
+      }
 
-          GetAccessibleDescription(child, [&flag](BSTR bstr) {
-            std::wstring_view bstr_view(bstr);
-            flag = bstr_view.find_first_of(L".:") != std::wstring_view::npos &&
-                   bstr_view.substr(0, 11) != L"javascript:";
-          });
-        });
+      GetAccessibleDescription(child, [&flag](BSTR bstr) {
+        std::wstring_view bstr_view(bstr);
+        flag = bstr_view.find_first_of(L".:") != std::wstring_view::npos &&
+               bstr_view.substr(0, 11) != L"javascript:";
+      });
+    });
 
-        return flag;
-      },
-      false, true);  // useBFS
+    return flag;
+  });
 
   return flag;
 }