From 667a2e737849dd39919fd70114e932594e083c84 Mon Sep 17 00:00:00 2001 From: Keith Horton Date: Mon, 25 Sep 2023 23:21:33 -0700 Subject: [PATCH] Implementing a simple registry watcher, which maintains the same contract as the unique_threadpool objects it uses for the callbacks - it guarantees all callbacks have completed and no more will be invoked when the d'tor has completed --- include/wil/registry.h | 204 ++++++++++++++++++++++++++++++++++++---- tests/RegistryTests.cpp | 124 ++++++++++++++++++++++++ 2 files changed, 309 insertions(+), 19 deletions(-) diff --git a/include/wil/registry.h b/include/wil/registry.h index 1f324b289..3b18447c5 100644 --- a/include/wil/registry.h +++ b/include/wil/registry.h @@ -3109,6 +3109,8 @@ namespace wil } }; + constexpr DWORD registry_notify_filter = REG_NOTIFY_CHANGE_LAST_SET | REG_NOTIFY_CHANGE_NAME | REG_NOTIFY_THREAD_AGNOSTIC; + inline void delete_registry_watcher_state(_In_opt_ registry_watcher_state* watcherStorage) { watcherStorage->Release(); } typedef resource_policy + class slim_registry_watcher_t + { + public: + // HRESULT or void error handling + typedef typename err_policy::result result; + + slim_registry_watcher_t() WI_NOEXCEPT = default; + + // Exception-based constructors + slim_registry_watcher_t(HKEY rootKey, _In_ PCWSTR subKey, bool isRecursive, ::wistd::function&& callback) + { + static_assert(::wistd::is_same::value, "this constructor requires exceptions; use the create method"); + create(rootKey, subKey, isRecursive, ::wistd::move(callback)); + } + + slim_registry_watcher_t(::wil::unique_hkey&& keyToWatch, bool isRecursive, ::wistd::function&& callback) + { + static_assert(::wistd::is_same::value, "this constructor requires exceptions; use the create method"); + create(::wistd::move(keyToWatch), isRecursive, ::wistd::move(callback)); + } + + // Pass a root key, sub key pair or use an empty string to use rootKey as the key to watch. + result create(HKEY rootKey, _In_ PCWSTR subKey, bool isRecursive, ::wistd::function&& callback) + { + ::wil::unique_hkey keyToWatch; + HRESULT hr = HRESULT_FROM_WIN32(::RegCreateKeyExW(rootKey, subKey, 0, nullptr, 0, KEY_NOTIFY, nullptr, &keyToWatch, nullptr)); + if (FAILED(hr)) + { + return err_policy::HResult(hr); + } + return err_policy::HResult(create_common(::wistd::move(keyToWatch), isRecursive, ::wistd::move(callback))); + } + + result create(::wil::unique_hkey&& keyToWatch, bool isRecursive, ::wistd::function&& callback) + { + return err_policy::HResult(create_common(::wistd::move(keyToWatch), isRecursive, ::wistd::move(callback))); + } + + private: + // using the default d'tor, destruction must occur in this order + ::wistd::function m_callback; + ::wil::unique_hkey m_keyToWatch; + ::wil::unique_event_nothrow m_eventHandle; + ::wil::unique_threadpool_wait m_threadPoolWait; + bool m_isRecursive; + + static void __stdcall callback(PTP_CALLBACK_INSTANCE, void* context, TP_WAIT*, TP_WAIT_RESULT) WI_NOEXCEPT + { + const auto this_ptr = static_cast(context); + + const LSTATUS error = ::RegNotifyChangeKeyValue( + this_ptr->m_keyToWatch.get(), this_ptr->m_isRecursive, ::wil::details::registry_notify_filter, this_ptr->m_eventHandle.get(), TRUE); + + // Call the client before re-arming to ensure that multiple callbacks don't + // run concurrently. + switch (error) + { + case ERROR_SUCCESS: + case ERROR_ACCESS_DENIED: + // Normal modification: send RegistryChangeKind::Modify and re-arm. + this_ptr->m_callback(::wil::RegistryChangeKind::Modify); + ::SetThreadpoolWait(this_ptr->m_threadPoolWait.get(), this_ptr->m_eventHandle.get(), nullptr); + break; + + case ERROR_KEY_DELETED: + // Key deleted: send RegistryChangeKind::Delete but do not re-arm. + this_ptr->m_callback(::wil::RegistryChangeKind::Delete); + break; + + case ERROR_HANDLE_REVOKED: + // Handle revoked. This can occur if the user session ends before the watcher shuts-down. + // Do not re-arm since there is generally no way to respond. + break; + + default: + // failure here is a programming error. + FAIL_FAST_HR(HRESULT_FROM_WIN32(error)); + } + } + + // This function exists to avoid template expansion of this code based on err_policy. + HRESULT create_common(::wil::unique_hkey&& keyToWatch, bool isRecursive, wistd::function&& callback) WI_NOEXCEPT + { + RETURN_IF_FAILED(m_eventHandle.create()); + + m_threadPoolWait.reset(CreateThreadpoolWait(&slim_registry_watcher_t::callback, this, nullptr)); + RETURN_LAST_ERROR_IF(!m_threadPoolWait); + + // associate the notification handle with the threadpool before passing it to RegNotifyChangeKeyValue so we get immediate callbacks in the tp + SetThreadpoolWait(m_threadPoolWait.get(), m_eventHandle.get(), nullptr); + + // 'this' object must be fully created before calling RegNotifyChangeKeyValue, as callbacks can start immediately + m_keyToWatch = wistd::move(keyToWatch); + m_isRecursive = isRecursive; + m_callback = wistd::move(callback); + + // no failures after RegNotifyChangeKeyValue succeeds, + RETURN_IF_WIN32_ERROR(RegNotifyChangeKeyValue(m_keyToWatch.get(), m_isRecursive, ::wil::details::registry_notify_filter, m_eventHandle.get(), TRUE)); + return S_OK; + } + }; + template class registry_watcher_t : public storage_t { @@ -3173,9 +3278,9 @@ namespace wil // using auto reset event so don't need to manually reset. // failure here is a programming error. - const LSTATUS error = RegNotifyChangeKeyValue(watcherState->m_keyToWatch.get(), watcherState->m_isRecursive, - REG_NOTIFY_CHANGE_LAST_SET | REG_NOTIFY_CHANGE_NAME | REG_NOTIFY_THREAD_AGNOSTIC, - watcherState->m_eventHandle.get(), TRUE); + const LSTATUS error = RegNotifyChangeKeyValue( + watcherState->m_keyToWatch.get(), watcherState->m_isRecursive, + ::wil::details::registry_notify_filter, watcherState->m_eventHandle.get(), TRUE); // Call the client before re-arming to ensure that multiple callbacks don't // run concurrently. @@ -3213,9 +3318,9 @@ namespace wil wistd::move(keyToWatch), isRecursive, wistd::move(callback))); RETURN_IF_NULL_ALLOC(watcherState); RETURN_IF_FAILED(watcherState->m_eventHandle.create()); - RETURN_IF_WIN32_ERROR(RegNotifyChangeKeyValue(watcherState->m_keyToWatch.get(), - watcherState->m_isRecursive, REG_NOTIFY_CHANGE_LAST_SET | REG_NOTIFY_CHANGE_NAME | REG_NOTIFY_THREAD_AGNOSTIC, - watcherState->m_eventHandle.get(), TRUE)); + RETURN_IF_WIN32_ERROR(RegNotifyChangeKeyValue( + watcherState->m_keyToWatch.get(), watcherState->m_isRecursive, + ::wil::details::registry_notify_filter, watcherState->m_eventHandle.get(), TRUE)); watcherState->m_threadPoolWait.reset(CreateThreadpoolWait(®istry_watcher_t::callback, watcherState.get(), nullptr)); RETURN_LAST_ERROR_IF(!watcherState->m_threadPoolWait); @@ -3225,44 +3330,105 @@ namespace wil } }; - typedef unique_any_t, err_returncode_policy>> unique_registry_watcher_nothrow; - typedef unique_any_t, err_failfast_policy>> unique_registry_watcher_failfast; + typedef ::wil::unique_any_t, err_returncode_policy>> unique_registry_watcher_nothrow; + typedef ::wil::unique_any_t, err_failfast_policy>> unique_registry_watcher_failfast; + + typedef ::wistd::unique_ptr<::wil::slim_registry_watcher_t> unique_slim_registry_watcher_nothrow; + typedef ::wistd::unique_ptr<::wil::slim_registry_watcher_t> unique_slim_registry_watcher_failfast; - inline unique_registry_watcher_nothrow make_registry_watcher_nothrow(HKEY rootKey, _In_ PCWSTR subKey, bool isRecursive, wistd::function&& callback) WI_NOEXCEPT + inline ::wil::unique_registry_watcher_nothrow make_registry_watcher_nothrow(HKEY rootKey, _In_ PCWSTR subKey, bool isRecursive, wistd::function&& callback) WI_NOEXCEPT { unique_registry_watcher_nothrow watcher; watcher.create(rootKey, subKey, isRecursive, wistd::move(callback)); return watcher; // caller must test for success using if (watcher) } + inline ::wil::unique_slim_registry_watcher_nothrow make_slim_registry_watcher_nothrow(HKEY rootKey, _In_ PCWSTR subKey, bool isRecursive, wistd::function&& callback) WI_NOEXCEPT + { + auto watcher = wil::make_unique_nothrow<::wil::slim_registry_watcher_t>(); + if (watcher) + { + if (FAILED(watcher->create(rootKey, subKey, isRecursive, ::wistd::move(callback)))) + { + watcher.reset(); + } + } + // caller must test for success using if (watcher) + return watcher; + } - inline unique_registry_watcher_nothrow make_registry_watcher_nothrow(unique_hkey&& keyToWatch, bool isRecursive, wistd::function&& callback) WI_NOEXCEPT + inline ::wil::unique_registry_watcher_nothrow make_registry_watcher_nothrow(unique_hkey&& keyToWatch, bool isRecursive, wistd::function&& callback) WI_NOEXCEPT { unique_registry_watcher_nothrow watcher; watcher.create(wistd::move(keyToWatch), isRecursive, wistd::move(callback)); return watcher; // caller must test for success using if (watcher) } + inline ::wil::unique_slim_registry_watcher_nothrow make_slim_registry_watcher_nothrow(unique_hkey&& keyToWatch, bool isRecursive, wistd::function&& callback) WI_NOEXCEPT + { + auto watcher = wil::make_unique_nothrow<::wil::slim_registry_watcher_t>(); + if (watcher) + { + watcher->create(::wistd::move(keyToWatch), isRecursive, ::wistd::move(callback)); + } + // caller must test for success using if (watcher) + return watcher; + } - inline unique_registry_watcher_failfast make_registry_watcher_failfast(HKEY rootKey, _In_ PCWSTR subKey, bool isRecursive, wistd::function&& callback) + inline ::wil::unique_registry_watcher_failfast make_registry_watcher_failfast(HKEY rootKey, _In_ PCWSTR subKey, bool isRecursive, wistd::function&& callback) { - return unique_registry_watcher_failfast(rootKey, subKey, isRecursive, wistd::move(callback)); + return ::wil::unique_registry_watcher_failfast(rootKey, subKey, isRecursive, wistd::move(callback)); + } + inline ::wil::unique_slim_registry_watcher_failfast make_slim_registry_watcher_failfast(HKEY rootKey, _In_ PCWSTR subKey, bool isRecursive, wistd::function&& callback) + { + auto watcher = wil::make_unique_failfast<::wil::slim_registry_watcher_t>(); + if (watcher) + { + watcher->create(rootKey, subKey, isRecursive, ::wistd::move(callback)); + } + // caller must test for success using if (watcher) + return watcher; } - inline unique_registry_watcher_failfast make_registry_watcher_failfast(unique_hkey&& keyToWatch, bool isRecursive, wistd::function&& callback) + inline ::wil::unique_registry_watcher_failfast make_registry_watcher_failfast(unique_hkey&& keyToWatch, bool isRecursive, wistd::function&& callback) { - return unique_registry_watcher_failfast(wistd::move(keyToWatch), isRecursive, wistd::move(callback)); + return ::wil::unique_registry_watcher_failfast(wistd::move(keyToWatch), isRecursive, wistd::move(callback)); + } + inline ::wil::unique_slim_registry_watcher_failfast make_slim_registry_watcher_failfast(unique_hkey&& keyToWatch, bool isRecursive, wistd::function&& callback) + { + auto watcher = wil::make_unique_failfast<::wil::slim_registry_watcher_t>(); + if (watcher) + { + watcher->create(::wistd::move(keyToWatch), isRecursive, ::wistd::move(callback)); + } + // caller must test for success using if (watcher) + return watcher; } #ifdef WIL_ENABLE_EXCEPTIONS - typedef unique_any_t, err_exception_policy >> unique_registry_watcher; + typedef ::wil::unique_any_t, err_exception_policy >> unique_registry_watcher; + typedef ::wistd::unique_ptr<::wil::slim_registry_watcher_t> unique_slim_registry_watcher; - inline unique_registry_watcher make_registry_watcher(HKEY rootKey, _In_ PCWSTR subKey, bool isRecursive, wistd::function&& callback) + inline ::wil::unique_registry_watcher make_registry_watcher(HKEY rootKey, _In_ PCWSTR subKey, bool isRecursive, wistd::function&& callback) { - return unique_registry_watcher(rootKey, subKey, isRecursive, wistd::move(callback)); + return ::wil::unique_registry_watcher(rootKey, subKey, isRecursive, wistd::move(callback)); + } + inline ::wil::unique_slim_registry_watcher make_slim_registry_watcher(HKEY rootKey, _In_ PCWSTR subKey, bool isRecursive, wistd::function&& callback) + { + auto watcher = wil::make_unique_nothrow<::wil::slim_registry_watcher_t>(); + THROW_IF_NULL_ALLOC(watcher.get()); + watcher->create(rootKey, subKey, isRecursive, ::wistd::move(callback)); + return watcher; } - inline unique_registry_watcher make_registry_watcher(unique_hkey&& keyToWatch, bool isRecursive, wistd::function&& callback) + inline ::wil::unique_registry_watcher make_registry_watcher(unique_hkey&& keyToWatch, bool isRecursive, wistd::function&& callback) + { + return ::wil::unique_registry_watcher(wistd::move(keyToWatch), isRecursive, wistd::move(callback)); + } + inline ::wil::unique_slim_registry_watcher make_slim_registry_watcher(unique_hkey&& keyToWatch, bool isRecursive, wistd::function&& callback) { - return unique_registry_watcher(wistd::move(keyToWatch), isRecursive, wistd::move(callback)); + auto watcher = wil::make_unique_nothrow<::wil::slim_registry_watcher_t>(); + THROW_IF_NULL_ALLOC(watcher.get()); + watcher->create(::wistd::move(keyToWatch), isRecursive, ::wistd::move(callback)); + return watcher; } #endif // WIL_ENABLE_EXCEPTIONS } // namespace wil diff --git a/tests/RegistryTests.cpp b/tests/RegistryTests.cpp index 187c9c07b..4e2c0bb63 100644 --- a/tests/RegistryTests.cpp +++ b/tests/RegistryTests.cpp @@ -5195,3 +5195,127 @@ TEST_CASE("BasicRegistryTests::key_heap_string_nothrow_iterator", "[registry]]") REQUIRE(count == 4); } } +TEST_CASE("BasicRegistryTests::slim_registry_watcher_t", "[registry]]") +{ + const auto deleteHr = HRESULT_FROM_WIN32(::RegDeleteTreeW(HKEY_CURRENT_USER, testSubkey)); + if (deleteHr != HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND)) + { + REQUIRE_SUCCEEDED(deleteHr); + } + + SECTION("unique_slim_registry_watcher_nothrow fails to be created") + { + // will fail if we pass invalid values - the substring must not be null + wil::unique_hkey hkey; + REQUIRE_SUCCEEDED(wil::reg::create_unique_key_nothrow(HKEY_CURRENT_USER, testSubkey, hkey)); + const auto watcher = wil::make_slim_registry_watcher_nothrow(hkey.get(), nullptr, true, [&](wil::RegistryChangeKind) {}); + REQUIRE(watcher.get() == nullptr); + } + + SECTION("unique_slim_registry_watcher_nothrow with recurssive changes") + { + wil::unique_hkey hkey; + REQUIRE_SUCCEEDED(wil::reg::create_unique_key_nothrow(HKEY_CURRENT_USER, testSubkey, hkey, wil::reg::key_access::readwrite)); + + wil::unique_event_nothrow callbackTracking; + REQUIRE_SUCCEEDED(callbackTracking.create()); + + uint32_t modifyCount = 0; + uint32_t deleteCount = 0; + const auto watcher = wil::make_slim_registry_watcher_nothrow(HKEY_CURRENT_USER, testSubkey, true, [&](wil::RegistryChangeKind kind) + { + switch (kind) + { + case wil::RegistryChangeKind::Modify: + ++modifyCount; + break; + case wil::RegistryChangeKind::Delete: + ++deleteCount; + break; + } + callbackTracking.SetEvent(); + }); + REQUIRE(watcher.get() != nullptr); + + for (uint32_t count = 0; count < 5; ++count) + { + callbackTracking.ResetEvent(); + REQUIRE_SUCCEEDED(wil::reg::set_value_nothrow(hkey.get(), L"test", count)); + REQUIRE(callbackTracking.wait(500)); + } + REQUIRE(modifyCount == 5); + + for (uint32_t count = 0; count < 5; ++count) + { + callbackTracking.ResetEvent(); + wil::unique_hkey embeddedKey; + REQUIRE_SUCCEEDED(wil::reg::create_unique_key_nothrow(hkey.get(), L"test\\test", embeddedKey)); + REQUIRE(callbackTracking.wait(500)); + + callbackTracking.ResetEvent(); + REQUIRE(::RegDeleteKeyW(hkey.get(), L"test\\test") == ERROR_SUCCESS); + REQUIRE(callbackTracking.wait(500)); + } + + // RegCreateKeyExW the first time had 2x callbacks + REQUIRE(modifyCount == 16); + + callbackTracking.ResetEvent(); + ::RegDeleteValueW(hkey.get(), L"test"); + REQUIRE(callbackTracking.wait(500)); + REQUIRE(modifyCount == 17); + + callbackTracking.ResetEvent(); + REQUIRE(::RegDeleteKeyW(hkey.get(), L"test") == ERROR_SUCCESS); + REQUIRE(callbackTracking.wait(500)); + REQUIRE(modifyCount == 18); + + callbackTracking.ResetEvent(); + REQUIRE(::RegDeleteKeyW(HKEY_CURRENT_USER, testSubkey) == ERROR_SUCCESS); + REQUIRE(callbackTracking.wait(500)); + REQUIRE(deleteCount == 1); + + // after deleting the key, should not have any more callbacks + callbackTracking.ResetEvent(); + REQUIRE_SUCCEEDED(wil::reg::create_unique_key_nothrow(HKEY_CURRENT_USER, testSubkey, hkey, wil::reg::key_access::readwrite)); + callbackTracking.wait(500); + REQUIRE(modifyCount == 18); + REQUIRE(deleteCount == 1); + + callbackTracking.ResetEvent(); + REQUIRE_SUCCEEDED(wil::reg::set_value_nothrow(hkey.get(), L"test", 0)); + callbackTracking.wait(500); + REQUIRE(modifyCount == 18); + REQUIRE(deleteCount == 1); + } + + SECTION("unique_slim_registry_watcher_nothrow guaranteeing d'tor waits on callbacks") + { + wil::unique_hkey hkey; + REQUIRE_SUCCEEDED(wil::reg::create_unique_key_nothrow(HKEY_CURRENT_USER, testSubkey, hkey, wil::reg::key_access::readwrite)); + + wil::unique_event_nothrow callbackEntered; + REQUIRE_SUCCEEDED(callbackEntered.create()); + + auto watcher = wil::make_slim_registry_watcher_nothrow(HKEY_CURRENT_USER, testSubkey, true, [&](wil::RegistryChangeKind) + { + callbackEntered.SetEvent(); + // now wait 5 seconds - ensuring we are in the d'tor on the main thread + Sleep(5000); + }); + REQUIRE(watcher.get() != nullptr); + + // initiate a change then destroy the watcher - it must wait for all callbacks + callbackEntered.ResetEvent(); + REQUIRE_SUCCEEDED(wil::reg::set_value_nothrow(hkey.get(), L"test", 0)); + REQUIRE(callbackEntered.wait(500)); + + // now we know we are in the callback - destroy the watcher + const auto startTime = GetTickCount64(); + watcher.reset(); + const auto endTime = GetTickCount64(); + + // should have waited ~5 seconds + REQUIRE(endTime - startTime > 4000); + } +}