From 2766a21681c1e395d4bcd4c0f178a2627cf56d23 Mon Sep 17 00:00:00 2001
From: Lukasz Dorau <lukasz.dorau@intel.com>
Date: Wed, 2 Oct 2024 11:02:12 +0200
Subject: [PATCH] Clear tracker for the current pool on destroy

Clear tracker for the current pool on destroy.
Do not print error messages if provider
does not support the free() operation.

Signed-off-by: Lukasz Dorau <lukasz.dorau@intel.com>
---
 src/memory_pool.c                |  8 +++-
 src/provider/provider_tracking.c | 76 ++++++++++++++++++++------------
 src/provider/provider_tracking.h |  3 +-
 test/memoryPoolAPI.cpp           |  6 ++-
 test/pools/disjoint_pool.cpp     |  5 ++-
 5 files changed, 64 insertions(+), 34 deletions(-)

diff --git a/src/memory_pool.c b/src/memory_pool.c
index 7d65acf36..f6ae8841f 100644
--- a/src/memory_pool.c
+++ b/src/memory_pool.c
@@ -41,8 +41,12 @@ static umf_result_t umfPoolCreateInternal(const umf_memory_pool_ops_t *ops,
     assert(ops->version == UMF_VERSION_CURRENT);
 
     if (!(flags & UMF_POOL_CREATE_FLAG_DISABLE_TRACKING)) {
-        // wrap provider with memory tracking provider
-        ret = umfTrackingMemoryProviderCreate(provider, pool, &pool->provider);
+        // Wrap provider with memory tracking provider.
+        // Check if the provider supports the free() operation.
+        bool upstreamDoesNotFree = (umfMemoryProviderFree(provider, NULL, 0) ==
+                                    UMF_RESULT_ERROR_NOT_SUPPORTED);
+        ret = umfTrackingMemoryProviderCreate(provider, pool, &pool->provider,
+                                              upstreamDoesNotFree);
         if (ret != UMF_RESULT_SUCCESS) {
             goto err_provider_create;
         }
diff --git a/src/provider/provider_tracking.c b/src/provider/provider_tracking.c
index 5e1db9d14..83aa5c335 100644
--- a/src/provider/provider_tracking.c
+++ b/src/provider/provider_tracking.c
@@ -141,6 +141,9 @@ typedef struct umf_tracking_memory_provider_t {
     umf_memory_tracker_handle_t hTracker;
     umf_memory_pool_handle_t pool;
     critnib *ipcCache;
+
+    // the upstream provider does not support the free() operation
+    bool upstreamDoesNotFree;
 } umf_tracking_memory_provider_t;
 
 typedef struct umf_tracking_memory_provider_t umf_tracking_memory_provider_t;
@@ -392,9 +395,11 @@ static umf_result_t trackingInitialize(void *params, void **ret) {
     return UMF_RESULT_SUCCESS;
 }
 
-#ifndef NDEBUG
-static void check_if_tracker_is_empty(umf_memory_tracker_handle_t hTracker,
-                                      umf_memory_pool_handle_t pool) {
+// TODO clearing the tracker is a temporary solution and should be removed.
+// The tracker should be cleared using the provider's free() operation.
+static void clear_tracker_for_the_pool(umf_memory_tracker_handle_t hTracker,
+                                       umf_memory_pool_handle_t pool,
+                                       bool upstreamDoesNotFree) {
     uintptr_t rkey;
     void *rvalue;
     size_t n_items = 0;
@@ -403,39 +408,55 @@ static void check_if_tracker_is_empty(umf_memory_tracker_handle_t hTracker,
     while (1 == critnib_find((critnib *)hTracker->map, last_key, FIND_G, &rkey,
                              &rvalue)) {
         tracker_value_t *value = (tracker_value_t *)rvalue;
-        if (value->pool == pool || pool == NULL) {
-            n_items++;
+        if (value->pool != pool && pool != NULL) {
+            last_key = rkey;
+            continue;
         }
 
+        n_items++;
+
+        void *removed_value = critnib_remove(hTracker->map, rkey);
+        assert(removed_value == rvalue);
+        umf_ba_free(hTracker->tracker_allocator, removed_value);
+
         last_key = rkey;
     }
 
-    if (n_items) {
-        // Do not assert if we are running in the proxy library,
-        // because it may need those resources till
-        // the very end of exiting the application.
-        if (!utils_is_running_in_proxy_lib()) {
-            if (pool) {
-                LOG_ERR("tracking provider of pool %p is not empty! "
-                        "(%zu items left)",
-                        (void *)pool, n_items);
-            } else {
-                LOG_ERR("tracking provider is not empty! (%zu items "
-                        "left)",
-                        n_items);
-            }
+#ifndef NDEBUG
+    // print error messages only if provider supports the free() operation
+    if (n_items && !upstreamDoesNotFree) {
+        if (pool) {
+            LOG_ERR(
+                "tracking provider of pool %p is not empty! (%zu items left)",
+                (void *)pool, n_items);
+        } else {
+            LOG_ERR("tracking provider is not empty! (%zu items left)",
+                    n_items);
         }
     }
+#else  /* DEBUG */
+    (void)upstreamDoesNotFree; // unused in DEBUG build
+    (void)n_items;             // unused in DEBUG build
+#endif /* DEBUG */
+}
+
+static void clear_tracker(umf_memory_tracker_handle_t hTracker) {
+    clear_tracker_for_the_pool(hTracker, NULL, false);
 }
-#endif /* NDEBUG */
 
 static void trackingFinalize(void *provider) {
     umf_tracking_memory_provider_t *p =
         (umf_tracking_memory_provider_t *)provider;
+
     critnib_delete(p->ipcCache);
-#ifndef NDEBUG
-    check_if_tracker_is_empty(p->hTracker, p->pool);
-#endif /* NDEBUG */
+
+    // Do not clear the tracker if we are running in the proxy library,
+    // because it may need those resources till
+    // the very end of exiting the application.
+    if (!utils_is_running_in_proxy_lib()) {
+        clear_tracker_for_the_pool(p->hTracker, p->pool,
+                                   p->upstreamDoesNotFree);
+    }
 
     umf_ba_global_free(provider);
 }
@@ -661,10 +682,11 @@ umf_memory_provider_ops_t UMF_TRACKING_MEMORY_PROVIDER_OPS = {
 
 umf_result_t umfTrackingMemoryProviderCreate(
     umf_memory_provider_handle_t hUpstream, umf_memory_pool_handle_t hPool,
-    umf_memory_provider_handle_t *hTrackingProvider) {
+    umf_memory_provider_handle_t *hTrackingProvider, bool upstreamDoesNotFree) {
 
     umf_tracking_memory_provider_t params;
     params.hUpstream = hUpstream;
+    params.upstreamDoesNotFree = upstreamDoesNotFree;
     params.hTracker = TRACKER;
     if (!params.hTracker) {
         LOG_ERR("failed, TRACKER is NULL");
@@ -739,16 +761,14 @@ void umfMemoryTrackerDestroy(umf_memory_tracker_handle_t handle) {
         return;
     }
 
-    // Do not destroy if we are running in the proxy library,
+    // Do not destroy the tracket if we are running in the proxy library,
     // because it may need those resources till
     // the very end of exiting the application.
     if (utils_is_running_in_proxy_lib()) {
         return;
     }
 
-#ifndef NDEBUG
-    check_if_tracker_is_empty(handle, NULL);
-#endif /* NDEBUG */
+    clear_tracker(handle);
 
     // We have to zero all inner pointers,
     // because the tracker handle can be copied
diff --git a/src/provider/provider_tracking.h b/src/provider/provider_tracking.h
index f020c3da8..9444ee475 100644
--- a/src/provider/provider_tracking.h
+++ b/src/provider/provider_tracking.h
@@ -11,6 +11,7 @@
 #define UMF_MEMORY_TRACKER_INTERNAL_H 1
 
 #include <assert.h>
+#include <stdbool.h>
 #include <stdlib.h>
 
 #include <umf/base.h>
@@ -53,7 +54,7 @@ umf_result_t umfMemoryTrackerGetAllocInfo(const void *ptr,
 // forwards all requests to hUpstream memory Provider. hUpstream lifetime should be managed by the user of this function.
 umf_result_t umfTrackingMemoryProviderCreate(
     umf_memory_provider_handle_t hUpstream, umf_memory_pool_handle_t hPool,
-    umf_memory_provider_handle_t *hTrackingProvider);
+    umf_memory_provider_handle_t *hTrackingProvider, bool upstreamDoesNotFree);
 
 void umfTrackingMemoryProviderGetUpstreamProvider(
     umf_memory_provider_handle_t hTrackingProvider,
diff --git a/test/memoryPoolAPI.cpp b/test/memoryPoolAPI.cpp
index 0fb2a4422..96fd634c6 100644
--- a/test/memoryPoolAPI.cpp
+++ b/test/memoryPoolAPI.cpp
@@ -139,7 +139,8 @@ TEST_P(umfPoolWithCreateFlagsTest, memoryPoolWithCustomProvider) {
 }
 
 TEST_F(test, retrieveMemoryProvider) {
-    umf_memory_provider_handle_t provider = (umf_memory_provider_handle_t)0x1;
+    auto nullProvider = umf_test::wrapProviderUnique(nullProviderCreate());
+    umf_memory_provider_handle_t provider = nullProvider.get();
 
     auto pool =
         wrapPoolUnique(createPoolChecked(umfProxyPoolOps(), provider, nullptr));
@@ -258,7 +259,8 @@ TEST_P(poolInitializeTest, errorPropagation) {
 }
 
 TEST_F(test, retrieveMemoryProvidersError) {
-    umf_memory_provider_handle_t provider = (umf_memory_provider_handle_t)0x1;
+    auto nullProvider = umf_test::wrapProviderUnique(nullProviderCreate());
+    umf_memory_provider_handle_t provider = nullProvider.get();
 
     auto pool =
         wrapPoolUnique(createPoolChecked(umfProxyPoolOps(), provider, nullptr));
diff --git a/test/pools/disjoint_pool.cpp b/test/pools/disjoint_pool.cpp
index d7612f4d5..2f5d61142 100644
--- a/test/pools/disjoint_pool.cpp
+++ b/test/pools/disjoint_pool.cpp
@@ -82,7 +82,10 @@ TEST_F(test, sharedLimits) {
         }
         umf_result_t free(void *ptr, [[maybe_unused]] size_t size) noexcept {
             ::free(ptr);
-            numFrees++;
+            // umfMemoryProviderFree(provider, NULL, 0) is called inside umfPoolCreateInternal()
+            if (ptr != NULL && size != 0) {
+                numFrees++;
+            }
             return UMF_RESULT_SUCCESS;
         }
     };