From 4d4b70ac2d479a9ff60e8b7f05ab03783c412598 Mon Sep 17 00:00:00 2001 From: Andrei Litvin Date: Tue, 17 May 2022 20:58:08 -0400 Subject: [PATCH] Use incremental resolution for minmdns discover/resolve (#18442) * Added an incremental minmdns resolver (squash merge of several commits) * Add header file to BUILD.gn * Add chip license * Restyle * Comment update * Add reset to incremental resolve for the common data as well * Remove unused constant * Simplify unit test logic: use records for data serialization and parsing * Updated tests - we now test commissionable nodes as well * Undo changes to Resolver.h * Cleaner naming for the txt parser delegate * Use clearer naming for SRV target - it is called target in RFC so use target host name * Code review update: switch detail logging to tracing events * Code review comments * Updated comments based on code review * More code reivew comments * Fix unit test off by one error * Update based on code review comment * Add string support class for unit testing qnames * Updated usage of Full names - less hardcoding * Name "bit flags" as such to make it clear they have to be bit flags * Update flag setting to be obviously bitflags * Added comments to TestQName helper class * Restyle * Fix unit test build rules * RAII for reset on init, add test for this * Restyle * Make linter happy: no else after return * Some changes to try to use incremental resolver * Code compiles, added some logic that should mostly cover except actual AAAA requests not available * Remove PTR parsing for commissioning - there seems to be no use for this right now * Get rid of mDiscoveryType * Compile works with some logic for AAAA requesting. Not timeouts for AAAA though * Move ActiveResolveAttempt types to Variant. Will start expanding with AAAA query support, so switching to a slightly more extensible way of storing values. * Fix unit test * HeapQName addition * Updated schedule retries. Still need marking * Implement AAAA fetching * Pass on interface id for IP addresses * Use a constant for parallel resolve count * Fix gni * Add expiry logic for SRV resolution * Start adding support for marking IP address resolution completed * Mark AAAA query resolution done * Remove unused constant * Restyle * Fix misplaced return for resolver initialization * Fix typo in message * Remove empty qname test: compiler complains about 0 with 0 comparison * Remove one more unused constant * Restyle * Fix typos * Fix typos * Switch a header-only list from static_library to source_set. Darwin refuses to compile a static library without cpp sources * initialize element count in HealQName * initialize element count in HealQName * Update python unit tests a bit - say when killing the app on purpose in logs, better log coloring and logic (do not hardcode binary bits and rely on modules * Ensure resolverproxy clears up after itself in the destructor: should clear any delegates set to an object about to get deleted * Proper shutdown of resolverproxy in platform implementation * Add a log when test script exits with non-zero exit code * Add more logging to try to help debug repl tests * Add support for script-gdb for python repl scripts, to give a backtrace if a test crashes * Restyle * Fix typo in python test run split * More operationla resolve cleanup. ResolverProxy seems to break MinMdns because of dangling pointers, only patched it up however usage of this object should be removed * Remove usage for script-gdb for yaml tests. Leave that for local runs only * Only unregister the commisionable delegate * Remove some internal debug methods * Remove extra log that shows up during chip tool test list --- .github/workflows/tests.yaml | 2 +- scripts/requirements.txt | 3 + scripts/tests/run_python_test.py | 45 +- .../AbstractDnssdDiscoveryController.h | 2 +- .../CHIPCommissionableNodeController.cpp | 5 + .../CHIPCommissionableNodeController.h | 2 +- .../python/test/test_scripts/base.py | 5 + src/lib/core/BUILD.gn | 1 + src/lib/core/CHIPConfig.h | 13 + src/lib/core/core.gni | 3 + src/lib/dnssd/ActiveResolveAttempts.cpp | 48 +- src/lib/dnssd/ActiveResolveAttempts.h | 41 +- src/lib/dnssd/Discovery_ImplPlatform.cpp | 6 + src/lib/dnssd/IncrementalResolve.cpp | 98 +-- src/lib/dnssd/IncrementalResolve.h | 23 +- src/lib/dnssd/ResolverProxy.h | 11 +- src/lib/dnssd/Resolver_ImplMinimalMdns.cpp | 824 +++++++++--------- src/lib/dnssd/Resolver_ImplNone.cpp | 5 + src/lib/dnssd/minimal_mdns/core/HeapQName.h | 166 ++++ .../dnssd/minimal_mdns/core/tests/BUILD.gn | 4 +- .../minimal_mdns/core/tests/TestHeapQName.cpp | 110 +++ src/lib/dnssd/platform/tests/TestPlatform.cpp | 1 + .../dnssd/tests/TestIncrementalResolve.cpp | 2 +- src/lib/support/ScopedBuffer.h | 2 +- 24 files changed, 886 insertions(+), 536 deletions(-) create mode 100644 src/lib/dnssd/minimal_mdns/core/HeapQName.h create mode 100644 src/lib/dnssd/minimal_mdns/core/tests/TestHeapQName.cpp diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index b6ab370a4c6849..cc50432aa0c0ee 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -272,7 +272,7 @@ jobs: - name: Run Tests timeout-minutes: 30 run: | - scripts/run_in_build_env.sh './scripts/tests/run_python_test.py --app out/linux-x64-all-clusters-no-ble-no-wifi-tsan-clang-test/chip-all-clusters-app --factoryreset --script-args "-t 3600 --disable-test ClusterObjectTests.TestTimedRequestTimeout"' + scripts/run_in_build_env.sh './scripts/tests/run_python_test.py --app out/linux-x64-all-clusters-no-ble-no-wifi-tsan-clang-test/chip-all-clusters-app --factoryreset --script-args "--log-level INFO -t 3600 --disable-test ClusterObjectTests.TestTimedRequestTimeout"' - name: Uploading core files uses: actions/upload-artifact@v2 if: ${{ failure() }} && ${{ !env.ACT }} diff --git a/scripts/requirements.txt b/scripts/requirements.txt index 9a255fc1ec30ae..6aa7d817b37154 100644 --- a/scripts/requirements.txt +++ b/scripts/requirements.txt @@ -55,3 +55,6 @@ lark stringcase cryptography + +# python unit tests +colorama diff --git a/scripts/tests/run_python_test.py b/scripts/tests/run_python_test.py index b94188d72f7a2a..7ebd4b8729d990 100755 --- a/scripts/tests/run_python_test.py +++ b/scripts/tests/run_python_test.py @@ -14,19 +14,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pty -import subprocess import click +import coloredlogs +import datetime +import logging import os import pathlib -import typing +import pty import queue -import threading +import shlex +import signal +import subprocess import sys +import threading import time -import datetime -import shlex -import logging +import typing + +from colorama import Fore, Style DEFAULT_CHIP_ROOT = os.path.abspath( os.path.join(os.path.dirname(__file__), '..', '..')) @@ -58,9 +62,9 @@ def RedirectQueueThread(fp, tag, queue) -> threading.Thread: def DumpProgramOutputToQueue(thread_list: typing.List[threading.Thread], tag: str, process: subprocess.Popen, queue: queue.Queue): thread_list.append(RedirectQueueThread(process.stdout, - (f"[{tag}][\33[33mSTDOUT\33[0m]").encode(), queue)) + (f"[{tag}][{Fore.YELLOW}STDOUT{Style.RESET_ALL}]").encode(), queue)) thread_list.append(RedirectQueueThread(process.stderr, - (f"[{tag}][\33[31mSTDERR\33[0m]").encode(), queue)) + (f"[{tag}][{Fore.RED}STDERR{Style.RESET_ALL}]").encode(), queue)) @click.command() @@ -69,12 +73,15 @@ def DumpProgramOutputToQueue(thread_list: typing.List[threading.Thread], tag: st @click.option("--app-args", type=str, default='', help='The extra arguments passed to the device.') @click.option("--script", type=click.Path(exists=True), default=os.path.join(DEFAULT_CHIP_ROOT, 'src', 'controller', 'python', 'test', 'test_scripts', 'mobile-device-test.py'), help='Test script to use.') @click.option("--script-args", type=str, default='', help='Path to the test script to use, omit to use the default test script (mobile-device-test.py).') -def main(app: str, factoryreset: bool, app_args: str, script: str, script_args: str): +@click.option("--script-gdb", is_flag=True, help='Run script through gdb') +def main(app: str, factoryreset: bool, app_args: str, script: str, script_args: str, script_gdb: bool): if factoryreset: retcode = subprocess.call("rm -rf /tmp/chip* /tmp/repl*", shell=True) if retcode != 0: raise Exception("Failed to remove /tmp/chip* for factory reset.") + coloredlogs.install(level='INFO') + log_queue = queue.Queue() log_cooking_threads = [] @@ -88,21 +95,31 @@ def main(app: str, factoryreset: bool, app_args: str, script: str, script_args: app_process = subprocess.Popen( app_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=0) DumpProgramOutputToQueue( - log_cooking_threads, "\33[34mAPP \33[0m", app_process, log_queue) + log_cooking_threads, Fore.GREEN + "APP " + Style.RESET_ALL, app_process, log_queue) - script_command = ["/usr/bin/env", "python3", script, "--paa-trust-store-path", os.path.join(DEFAULT_CHIP_ROOT, MATTER_DEVELOPMENT_PAA_ROOT_CERTS), + script_command = [script, "--paa-trust-store-path", os.path.join(DEFAULT_CHIP_ROOT, MATTER_DEVELOPMENT_PAA_ROOT_CERTS), '--log-format', '%(message)s'] + shlex.split(script_args) + + if script_gdb: + script_command = "gdb -batch -return-child-result -q -ex run -ex bt --args python3".split() + script_command + else: + script_command = "/usr/bin/env python3".split() + script_command + logging.info(f"Execute: {script_command}") test_script_process = subprocess.Popen( script_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - DumpProgramOutputToQueue(log_cooking_threads, "\33[32mTEST\33[0m", + DumpProgramOutputToQueue(log_cooking_threads, Fore.GREEN + "TEST" + Style.RESET_ALL, test_script_process, log_queue) test_script_exit_code = test_script_process.wait() + if test_script_exit_code != 0: + logging.error("Test script exited with error %r" % test_script_exit_code) + test_app_exit_code = 0 if app_process: - app_process.send_signal(2) + logging.warning("Stopping app with SIGINT") + app_process.send_signal(signal.SIGINT.value) test_app_exit_code = app_process.wait() # There are some logs not cooked, so we wait until we have processed all logs. diff --git a/src/controller/AbstractDnssdDiscoveryController.h b/src/controller/AbstractDnssdDiscoveryController.h index 3372aab1f67bf8..55e148f68d21af 100644 --- a/src/controller/AbstractDnssdDiscoveryController.h +++ b/src/controller/AbstractDnssdDiscoveryController.h @@ -40,7 +40,7 @@ class DLL_EXPORT AbstractDnssdDiscoveryController : public Dnssd::CommissioningR { public: AbstractDnssdDiscoveryController() {} - ~AbstractDnssdDiscoveryController() override {} + ~AbstractDnssdDiscoveryController() override { mDNSResolver.Shutdown(); } void OnNodeDiscovered(const chip::Dnssd::DiscoveredNodeData & nodeData) override; diff --git a/src/controller/CHIPCommissionableNodeController.cpp b/src/controller/CHIPCommissionableNodeController.cpp index 47f1299c07ac6f..b2b8b31d317a69 100644 --- a/src/controller/CHIPCommissionableNodeController.cpp +++ b/src/controller/CHIPCommissionableNodeController.cpp @@ -48,6 +48,11 @@ CHIP_ERROR CommissionableNodeController::DiscoverCommissioners(Dnssd::DiscoveryF return mResolver->FindCommissioners(discoveryFilter); } +CommissionableNodeController::~CommissionableNodeController() +{ + mDNSResolver.SetCommissioningDelegate(nullptr); +} + const Dnssd::DiscoveredNodeData * CommissionableNodeController::GetDiscoveredCommissioner(int idx) { return GetDiscoveredNode(idx); diff --git a/src/controller/CHIPCommissionableNodeController.h b/src/controller/CHIPCommissionableNodeController.h index 53b77776d72abf..f6bf2532b53b09 100644 --- a/src/controller/CHIPCommissionableNodeController.h +++ b/src/controller/CHIPCommissionableNodeController.h @@ -37,7 +37,7 @@ class DLL_EXPORT CommissionableNodeController : public AbstractDnssdDiscoveryCon { public: CommissionableNodeController(chip::Dnssd::Resolver * resolver = nullptr) : mResolver(resolver) {} - ~CommissionableNodeController() override {} + ~CommissionableNodeController() override; CHIP_ERROR DiscoverCommissioners(Dnssd::DiscoveryFilter discoveryFilter = Dnssd::DiscoveryFilter()); diff --git a/src/controller/python/test/test_scripts/base.py b/src/controller/python/test/test_scripts/base.py index ca2f1419dd4abe..df2c33ed1a9c20 100644 --- a/src/controller/python/test/test_scripts/base.py +++ b/src/controller/python/test/test_scripts/base.py @@ -381,6 +381,8 @@ async def TestMultiFabric(self, ip: str, setuppin: int, nodeid: int): ChipDeviceCtrl.ChipDeviceController.ShutdownAll() chip.FabricAdmin.FabricAdmin.ShutdownAll() + self.logger.info("Shutdown completed, starting new controllers...") + self.fabricAdmin = chip.FabricAdmin.FabricAdmin( fabricId=1, fabricIndex=1) fabricAdmin2 = chip.FabricAdmin.FabricAdmin(fabricId=2, fabricIndex=2) @@ -390,6 +392,8 @@ async def TestMultiFabric(self, ip: str, setuppin: int, nodeid: int): self.devCtrl2 = fabricAdmin2.NewController( self.controllerNodeId, self.paaTrustStorePath) + self.logger.info("Waiting for attribute reads...") + data1 = await self.devCtrl.ReadAttribute(nodeid, [(Clusters.OperationalCredentials.Attributes.NOCs)], fabricFiltered=False) data2 = await self.devCtrl2.ReadAttribute(nodeid, [(Clusters.OperationalCredentials.Attributes.NOCs)], fabricFiltered=False) @@ -414,6 +418,7 @@ async def TestMultiFabric(self, ip: str, setuppin: int, nodeid: int): "Got back fabric indices that match for two different fabrics!") return False + self.logger.info("Attribute reads completed...") return True async def TestFabricSensitive(self, nodeid: int): diff --git a/src/lib/core/BUILD.gn b/src/lib/core/BUILD.gn index 70e12672891c81..147e589d045ff5 100644 --- a/src/lib/core/BUILD.gn +++ b/src/lib/core/BUILD.gn @@ -61,6 +61,7 @@ buildconfig_header("chip_buildconfig") { "CHIP_CONFIG_TRANSPORT_TRACE_ENABLED=${chip_enable_transport_trace}", "CHIP_CONFIG_TRANSPORT_PW_TRACE_ENABLED=${chip_enable_transport_pw_trace}", "CHIP_CONFIG_MINMDNS_DYNAMIC_OPERATIONAL_RESPONDER_LIST=${chip_config_minmdns_dynamic_operational_responder_list}", + "CHIP_CONFIG_MINMDNS_MAX_PARALLEL_RESOLVES=${chip_config_minmdns_max_parallel_resolves}", ] } diff --git a/src/lib/core/CHIPConfig.h b/src/lib/core/CHIPConfig.h index 3e247c77eb5fe4..2dbbc867a817b0 100644 --- a/src/lib/core/CHIPConfig.h +++ b/src/lib/core/CHIPConfig.h @@ -1210,6 +1210,19 @@ extern const char CHIP_NON_PRODUCTION_MARKER[]; #define CHIP_CONFIG_MINMDNS_DYNAMIC_OPERATIONAL_RESPONDER_LIST 0 #endif // CHIP_CONFIG_MINMDNS_DYNAMIC_OPERATIONAL_RESPONDER_LIST +/* + * @def CHIP_CONFIG_MINMDNS_MAX_PARALLEL_RESOLVES + * + * @brief Determines the maximum number of SRV records that can be processed in parallel. + * Affects maximum number of results received for browse requests + * (where a single packet may contain multiple SRV entries) + * or number of pending resolves that still require a AAAA IP record + * to be resolved. + */ +#ifndef CHIP_CONFIG_MINMDNS_MAX_PARALLEL_RESOLVES +#define CHIP_CONFIG_MINMDNS_MAX_PARALLEL_RESOLVES 2 +#endif // CHIP_CONFIG_MINMDNS_MAX_PARALLEL_RESOLVES + /* * @def CHIP_CONFIG_NETWORK_COMMISSIONING_DEBUG_TEXT_BUFFER_SIZE * diff --git a/src/lib/core/core.gni b/src/lib/core/core.gni index f2e78db68a14dc..3034061342546b 100644 --- a/src/lib/core/core.gni +++ b/src/lib/core/core.gni @@ -64,6 +64,9 @@ declare_args() { # of tracking information for operational advertisement. chip_config_minmdns_dynamic_operational_responder_list = current_os == "linux" || current_os == "android" || current_os == "darwin" + + # When using minmdns, set the number of parallel resolves + chip_config_minmdns_max_parallel_resolves = 2 } if (chip_target_style == "") { diff --git a/src/lib/dnssd/ActiveResolveAttempts.cpp b/src/lib/dnssd/ActiveResolveAttempts.cpp index c645c0ed98fc22..60d8347d58117f 100644 --- a/src/lib/dnssd/ActiveResolveAttempts.cpp +++ b/src/lib/dnssd/ActiveResolveAttempts.cpp @@ -65,19 +65,34 @@ void ActiveResolveAttempts::Complete(const chip::Dnssd::DiscoveredNodeData & dat } } +void ActiveResolveAttempts::CompleteIpResolution(SerializedQNameIterator targetHostName) +{ + for (auto & item : mRetryQueue) + { + if (item.attempt.MatchesIpResolve(targetHostName)) + { + item.attempt.Clear(); + return; + } + } +} + void ActiveResolveAttempts::MarkPending(const chip::PeerId & peerId) { - ScheduledAttempt attempt(peerId, /* firstSend */ true); - MarkPending(attempt); + MarkPending(ScheduledAttempt(peerId, /* firstSend */ true)); } void ActiveResolveAttempts::MarkPending(const chip::Dnssd::DiscoveryFilter & filter, const chip::Dnssd::DiscoveryType type) { - ScheduledAttempt attempt(filter, type, /* firstSend */ true); - MarkPending(attempt); + MarkPending(ScheduledAttempt(filter, type, /* firstSend */ true)); +} + +void ActiveResolveAttempts::MarkPending(ScheduledAttempt::IpResolve && resolve) +{ + MarkPending(ScheduledAttempt(std::move(resolve), /* firstSend */ true)); } -void ActiveResolveAttempts::MarkPending(const ScheduledAttempt & attempt) +void ActiveResolveAttempts::MarkPending(ScheduledAttempt && attempt) { // Strategy when picking the peer id to use: // 1 if a matching peer id is already found, use that one @@ -211,5 +226,28 @@ Optional ActiveResolveAttempts::NextSch return Optional::Missing(); } +bool ActiveResolveAttempts::IsWaitingForIpResolutionFor(SerializedQNameIterator hostName) const +{ + for (auto & entry : mRetryQueue) + { + if (entry.attempt.IsEmpty()) + { + continue; // not a pending item + } + + if (!entry.attempt.IsIpResolve()) + { + continue; + } + + if (hostName == entry.attempt.IpResolveData().hostName.Content()) + { + return true; + } + } + + return false; +} + } // namespace Minimal } // namespace mdns diff --git a/src/lib/dnssd/ActiveResolveAttempts.h b/src/lib/dnssd/ActiveResolveAttempts.h index 7c71279c34f5d8..7ce121319ce3f2 100644 --- a/src/lib/dnssd/ActiveResolveAttempts.h +++ b/src/lib/dnssd/ActiveResolveAttempts.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -61,6 +62,12 @@ class ActiveResolveAttempts Resolve(chip::PeerId id) : peerId(id) {} }; + struct IpResolve + { + HeapQName hostName; + IpResolve(HeapQName && host) : hostName(std::move(host)) {} + }; + ScheduledAttempt() {} ScheduledAttempt(const chip::PeerId & peer, bool first) : resolveData(chip::InPlaceTemplateType(), peer), firstSend(first) @@ -68,6 +75,11 @@ class ActiveResolveAttempts ScheduledAttempt(const chip::Dnssd::DiscoveryFilter discoveryFilter, const chip::Dnssd::DiscoveryType type, bool first) : resolveData(chip::InPlaceTemplateType(), discoveryFilter, type), firstSend(first) {} + + ScheduledAttempt(IpResolve && ipResolve, bool first) : + resolveData(chip::InPlaceTemplateType(), ipResolve), firstSend(first) + {} + bool operator==(const ScheduledAttempt & other) const { return Matches(other) && other.firstSend == firstSend; } bool Matches(const ScheduledAttempt & other) const { @@ -99,9 +111,25 @@ class ActiveResolveAttempts return a.peerId == b.peerId; } + + if (resolveData.Is()) + { + if (!other.resolveData.Is()) + { + return false; + } + auto & a = resolveData.Get(); + auto & b = other.resolveData.Get(); + + return a.hostName == b.hostName; + } return false; } + bool MatchesIpResolve(SerializedQNameIterator hostName) const + { + return resolveData.Is() && (hostName == resolveData.Get().hostName.Content()); + } bool Matches(const chip::PeerId & peer) const { return resolveData.Is() && (resolveData.Get().peerId == peer); @@ -146,15 +174,18 @@ class ActiveResolveAttempts return false; } } + bool IsEmpty() const { return !resolveData.Valid(); } bool IsResolve() const { return resolveData.Is(); } bool IsBrowse() const { return resolveData.Is(); } + bool IsIpResolve() const { return resolveData.Is(); } void Clear() { resolveData = DataType(); } const Browse & BrowseData() const { return resolveData.Get(); } const Resolve & ResolveData() const { return resolveData.Get(); } + const IpResolve & IpResolveData() const { return resolveData.Get(); } - using DataType = chip::Variant; + using DataType = chip::Variant; DataType resolveData; @@ -171,6 +202,7 @@ class ActiveResolveAttempts /// Mark a resolution as a success, removing it from the internal list void Complete(const chip::PeerId & peerId); void Complete(const chip::Dnssd::DiscoveredNodeData & data); + void CompleteIpResolution(SerializedQNameIterator targetHostName); /// Mark that a resolution is pending, adding it to the internal list /// @@ -178,6 +210,7 @@ class ActiveResolveAttempts /// by NextScheduled (potentially with others as well) void MarkPending(const chip::PeerId & peerId); void MarkPending(const chip::Dnssd::DiscoveryFilter & filter, const chip::Dnssd::DiscoveryType type); + void MarkPending(ScheduledAttempt::IpResolve && resolve); // Get minimum time until the next pending reply is required. // @@ -194,6 +227,10 @@ class ActiveResolveAttempts // any peer that needs a new request sent chip::Optional NextScheduled(); + /// Check if any of the pending queries are for the given host name for + /// IP resolution. + bool IsWaitingForIpResolutionFor(SerializedQNameIterator hostName) const; + private: struct RetryEntry { @@ -211,7 +248,7 @@ class ActiveResolveAttempts // least a factor of two chip::System::Clock::Timeout nextRetryDelay = chip::System::Clock::Seconds16(1); }; - void MarkPending(const ScheduledAttempt & attempt); + void MarkPending(ScheduledAttempt && attempt); chip::System::Clock::ClockBase * mClock; RetryEntry mRetryQueue[kRetryQueueSize]; }; diff --git a/src/lib/dnssd/Discovery_ImplPlatform.cpp b/src/lib/dnssd/Discovery_ImplPlatform.cpp index bf06eef224f1f2..42c5d8c04b4701 100644 --- a/src/lib/dnssd/Discovery_ImplPlatform.cpp +++ b/src/lib/dnssd/Discovery_ImplPlatform.cpp @@ -349,6 +349,7 @@ CHIP_ERROR DiscoveryImplPlatform::InitImpl() void DiscoveryImplPlatform::Shutdown() { VerifyOrReturn(mDnssdInitialized); + mResolverProxy.Shutdown(); ChipDnssdShutdown(); } @@ -646,6 +647,11 @@ CHIP_ERROR ResolverProxy::ResolveNodeId(const PeerId & peerId, Inet::IPAddressTy return ChipDnssdResolve(&service, Inet::InterfaceId::Null(), HandleNodeIdResolve, mDelegate); } +ResolverProxy::~ResolverProxy() +{ + Shutdown(); +} + CHIP_ERROR ResolverProxy::FindCommissionableNodes(DiscoveryFilter filter) { VerifyOrReturnError(mDelegate != nullptr, CHIP_ERROR_INCORRECT_STATE); diff --git a/src/lib/dnssd/IncrementalResolve.cpp b/src/lib/dnssd/IncrementalResolve.cpp index 3de282f0f838b7..2ac6ffd2f0484e 100644 --- a/src/lib/dnssd/IncrementalResolve.cpp +++ b/src/lib/dnssd/IncrementalResolve.cpp @@ -60,13 +60,9 @@ enum class ServiceNameType }; // Common prefix to check for all operational/commissioner/commissionable name parts -constexpr QNamePart kOperationalSuffix[] = { kOperationalServiceName, kOperationalProtocol, kLocalDomain }; -constexpr QNamePart kCommissionableSuffix[] = { kCommissionableServiceName, kCommissionProtocol, kLocalDomain }; -constexpr QNamePart kCommissionerSuffix[] = { kCommissionerServiceName, kCommissionProtocol, kLocalDomain }; -constexpr QNamePart kCommissionableSubTypeSuffix[] = { kSubtypeServiceNamePart, kCommissionableServiceName, kCommissionProtocol, - kLocalDomain }; -constexpr QNamePart kCommissionerSubTypeSuffix[] = { kSubtypeServiceNamePart, kCommissionerServiceName, kCommissionProtocol, - kLocalDomain }; +constexpr QNamePart kOperationalSuffix[] = { kOperationalServiceName, kOperationalProtocol, kLocalDomain }; +constexpr QNamePart kCommissionableSuffix[] = { kCommissionableServiceName, kCommissionProtocol, kLocalDomain }; +constexpr QNamePart kCommissionerSuffix[] = { kCommissionerServiceName, kCommissionProtocol, kLocalDomain }; ServiceNameType ComputeServiceNameType(SerializedQNameIterator name) { @@ -100,18 +96,6 @@ ServiceNameType ComputeServiceNameType(SerializedQNameIterator name) return ServiceNameType::kInvalid; } -/// Checks if the name is of the form ._sub._matter(c|d)._udp.local -bool IsCommissionSubtype(SerializedQNameIterator name) -{ - if (!name.Next() || !name.IsValid()) - { - // subtype should be a prefix - return false; - } - - return (name == kCommissionerSubTypeSuffix) || (name == kCommissionableSubTypeSuffix); -} - /// Automatically resets a IncrementalResolver to inactive in its destructor /// unless disarmed. /// @@ -207,6 +191,17 @@ CHIP_ERROR IncrementalResolver::InitializeParsing(mdns::Minimal::SerializedQName case ServiceNameType::kCommissioner: case ServiceNameType::kCommissionable: mSpecificResolutionData.Set(); + + { + // Commission addresses start with instance name + SerializedQNameIterator nameCopy = name; + if (!nameCopy.Next() || !nameCopy.IsValid()) + { + return CHIP_ERROR_INVALID_ARGUMENT; + } + + Platform::CopyString(mSpecificResolutionData.Get().instanceName, nameCopy.Value()); + } break; default: return CHIP_ERROR_INVALID_ARGUMENT; @@ -235,7 +230,7 @@ IncrementalResolver::RequiredInformationFlags IncrementalResolver::GetMissingReq return flags; } -CHIP_ERROR IncrementalResolver::OnRecord(const ResourceData & data, BytesRange packetRange) +CHIP_ERROR IncrementalResolver::OnRecord(Inet::InterfaceId interface, const ResourceData & data, BytesRange packetRange) { MATTER_TRACE_EVENT_SCOPE("Incremental resolver record parsing"); // measure until loop finished @@ -246,8 +241,6 @@ CHIP_ERROR IncrementalResolver::OnRecord(const ResourceData & data, BytesRange p switch (data.GetType()) { - case QType::PTR: - return OnPtrRecord(data, packetRange); case QType::TXT: if (data.GetName() != mRecordName.Get()) { @@ -268,7 +261,7 @@ CHIP_ERROR IncrementalResolver::OnRecord(const ResourceData & data, BytesRange p return CHIP_ERROR_INVALID_ARGUMENT; } - return OnIpAddress(addr); + return OnIpAddress(interface, addr); } case QType::AAAA: { if (data.GetName() != mTargetHostName.Get()) @@ -283,7 +276,7 @@ CHIP_ERROR IncrementalResolver::OnRecord(const ResourceData & data, BytesRange p return CHIP_ERROR_INVALID_ARGUMENT; } - return OnIpAddress(addr); + return OnIpAddress(interface, addr); } case QType::SRV: // SRV handled on creation, ignored for 'additional data' default: @@ -294,50 +287,6 @@ CHIP_ERROR IncrementalResolver::OnRecord(const ResourceData & data, BytesRange p return CHIP_NO_ERROR; } -CHIP_ERROR IncrementalResolver::OnPtrRecord(const ResourceData & data, BytesRange packetRange) -{ - // Here we handle subtype expectations. Data is of the form: - // ._sub._mattrc._udp.local or - // ._sub._mattrd._udp.local - // - // If these hold, then we have to check if PTR points at the current record and - // if yes, the subtype matches and information can be extracted. - - if (!IsActiveCommissionParse()) - { - MATTER_TRACE_EVENT_INSTANT("PTR for non-commission"); - return CHIP_NO_ERROR; - } - - if (!IsCommissionSubtype(data.GetName())) - { - MATTER_TRACE_EVENT_INSTANT("PTR without a commission subtype"); - return CHIP_NO_ERROR; - } - - SerializedQNameIterator qname; - - if (!ParsePtrRecord(data.GetData(), packetRange, &qname)) - { - return CHIP_ERROR_INVALID_ARGUMENT; - } - - if (qname != mRecordName.Get()) - { - MATTER_TRACE_EVENT_INSTANT("PTR not applicable"); - return CHIP_NO_ERROR; - } - - // TODO: why are we not validating the string here? what is the purpose - // of copying and preserving the instance name here? - Platform::CopyString(mSpecificResolutionData.Get().instanceName, qname.Value()); - - // TODO: nothing is done with the sub name here. The instance name could be - // fetched from the SRV record, so why are we processing PTR records? - - return CHIP_NO_ERROR; -} - CHIP_ERROR IncrementalResolver::OnTxtRecord(const ResourceData & data, BytesRange packetRange) { { @@ -360,13 +309,24 @@ CHIP_ERROR IncrementalResolver::OnTxtRecord(const ResourceData & data, BytesRang return CHIP_NO_ERROR; } -CHIP_ERROR IncrementalResolver::OnIpAddress(const Inet::IPAddress & addr) +CHIP_ERROR IncrementalResolver::OnIpAddress(Inet::InterfaceId interface, const Inet::IPAddress & addr) { if (mCommonResolutionData.numIPs >= ArraySize(mCommonResolutionData.ipAddress)) { return CHIP_ERROR_NO_MEMORY; } + if (!mCommonResolutionData.interfaceId.IsPresent()) + { + mCommonResolutionData.interfaceId = interface; + } + else if (mCommonResolutionData.interfaceId != interface) + { + // IP addresses received from multiple packets over different interfaces. + // Processing is assumed per single interface. + return CHIP_ERROR_INVALID_ARGUMENT; + } + mCommonResolutionData.ipAddress[mCommonResolutionData.numIPs++] = addr; return CHIP_NO_ERROR; } diff --git a/src/lib/dnssd/IncrementalResolve.h b/src/lib/dnssd/IncrementalResolve.h index e0a69cfd8826e4..eff4f04fd28b72 100644 --- a/src/lib/dnssd/IncrementalResolve.h +++ b/src/lib/dnssd/IncrementalResolve.h @@ -106,9 +106,11 @@ class IncrementalResolver /// Providing a data that is not relevant to the current parser is not considered and error, /// however if the resource fails parsing completely an error will be returned. /// - /// [data] represents the record and [packetRange] represents the range of valid bytes within - /// the packet for the purpose of QName parsing - CHIP_ERROR OnRecord(const mdns::Minimal::ResourceData & data, mdns::Minimal::BytesRange packetRange); + /// + /// [data] represents the record received via [interface] and [packetRange] represents the range + /// of valid bytes within the packet for the purpose of QName parsing + CHIP_ERROR OnRecord(Inet::InterfaceId interface, const mdns::Minimal::ResourceData & data, + mdns::Minimal::BytesRange packetRange); /// Return what additional data is required until the object can be extracted /// @@ -116,12 +118,18 @@ class IncrementalResolver /// to be processed. RequiredInformationFlags GetMissingRequiredInformation() const; - /// Fetch the server name set by `InitializeParsing` + /// Fetch the target host name set by `InitializeParsing` /// /// VALIDITY: Data references internal storage of this object and is valid as long /// as this object is valid and InitializeParsing is not called again. mdns::Minimal::SerializedQNameIterator GetTargetHostName() const { return mTargetHostName.Get(); } + /// Fetch the record name set by `InitializeParsing`. + /// + /// VALIDITY: Data references internal storage of this object and is valid as long + /// as this object is valid and InitializeParsing is not called again. + mdns::Minimal::SerializedQNameIterator GetRecordName() const { return mRecordName.Get(); } + /// Take the current value of the object and clear it once returned. /// /// Object must be in `IsActiveCommissionParse()` for this to succeed. @@ -146,11 +154,6 @@ class IncrementalResolver } private: - /// Notify that a PTR record can be parsed. - /// - /// Input data MUST have GetType() == QType::PTR - CHIP_ERROR OnPtrRecord(const mdns::Minimal::ResourceData & data, mdns::Minimal::BytesRange packetRange); - /// Notify that a PTR record can be parsed. /// /// Input data MUST have GetType() == QType::TXT @@ -162,7 +165,7 @@ class IncrementalResolver /// addresses. /// /// Prerequisite: IP address belongs to the right nost name - CHIP_ERROR OnIpAddress(const Inet::IPAddress & addr); + CHIP_ERROR OnIpAddress(Inet::InterfaceId interface, const Inet::IPAddress & addr); using ParsedRecordSpecificData = Variant; diff --git a/src/lib/dnssd/ResolverProxy.h b/src/lib/dnssd/ResolverProxy.h index 57690c2dbb7920..7705ff0a1b2202 100644 --- a/src/lib/dnssd/ResolverProxy.h +++ b/src/lib/dnssd/ResolverProxy.h @@ -79,6 +79,7 @@ class ResolverProxy : public Resolver { public: ResolverProxy() {} + ~ResolverProxy() override; // Resolver interface. CHIP_ERROR Init(Inet::EndPointManager * udpEndPoint = nullptr) override @@ -115,7 +116,10 @@ class ResolverProxy : public Resolver } else { - ChipLogProgress(Discovery, "Delaying proxy of operational discovery: missing delegate"); + if (delegate != nullptr) + { + ChipLogProgress(Discovery, "Delaying proxy of operational discovery: missing delegate"); + } mPreInitOperationalDelegate = delegate; } } @@ -128,7 +132,10 @@ class ResolverProxy : public Resolver } else { - ChipLogError(Discovery, "Delaying proxy of commissioning discovery: missing delegate"); + if (delegate != nullptr) + { + ChipLogError(Discovery, "Delaying proxy of commissioning discovery: missing delegate"); + } mPreInitCommissioningDelegate = delegate; } } diff --git a/src/lib/dnssd/Resolver_ImplMinimalMdns.cpp b/src/lib/dnssd/Resolver_ImplMinimalMdns.cpp index 6d2e40f07c45dc..30500a6c800251 100644 --- a/src/lib/dnssd/Resolver_ImplMinimalMdns.cpp +++ b/src/lib/dnssd/Resolver_ImplMinimalMdns.cpp @@ -21,10 +21,10 @@ #include #include +#include #include #include #include -#include #include #include #include @@ -41,338 +41,212 @@ namespace chip { namespace Dnssd { namespace { -const ByteSpan GetSpan(const mdns::Minimal::BytesRange & range) -{ - return ByteSpan(range.Start(), range.Size()); -} - -template -class TxtRecordDelegateImpl : public mdns::Minimal::TxtRecordDelegate -{ -public: - explicit TxtRecordDelegateImpl(NodeData & nodeData) : mNodeData(nodeData) {} - void OnRecord(const mdns::Minimal::BytesRange & name, const mdns::Minimal::BytesRange & value) override - { - FillNodeDataFromTxt(GetSpan(name), GetSpan(value), mNodeData); - } - -private: - NodeData & mNodeData; -}; - constexpr size_t kMdnsMaxPacketSize = 1024; constexpr uint16_t kMdnsPort = 5353; using namespace mdns::Minimal; -class PacketDataReporter : public ParserDelegate +/// Handles processing of minmdns packet data. +/// +/// Can process multiple incremental resolves based on SRV data and allows +/// retrieval of pending (e.g. to ask for AAAA) and complete data items. +/// +class PacketParser : private ParserDelegate { public: - PacketDataReporter(OperationalResolveDelegate * opDelegate, CommissioningResolveDelegate * commissionDelegate, - chip::Inet::InterfaceId interfaceId, DiscoveryType discoveryType, const BytesRange & packet) : - mOperationalDelegate(opDelegate), - mCommissioningDelegate(commissionDelegate), mDiscoveryType(discoveryType), mPacketRange(packet) - { - mInterfaceId = interfaceId; - } + PacketParser(ActiveResolveAttempts & activeResolves) : mActiveResolves(activeResolves) {} - // ParserDelegate implementation + /// Goes through the given SRV records within a response packet + /// and sets up data resolution + void ParseSrvRecords(const BytesRange & packet); + + /// Goes through non-SRV records and feeds them through the initialized + /// SRV record parsing. + /// + /// Must be called AFTER ParseSrvRecords has been called. + void ParseNonSrvRecords(Inet::InterfaceId interface, const BytesRange & packet); + + IncrementalResolver * ResolverBegin() { return mResolvers; } + IncrementalResolver * ResolverEnd() { return mResolvers + kMinMdnsNumParallelResolvers; } +private: + // ParserDelegate implementation void OnHeader(ConstHeaderRef & header) override; void OnQuery(const QueryData & data) override; void OnResource(ResourceType type, const ResourceData & data) override; - // Called after ParsePacket is complete to send final notifications to the delegate. - // Used to ensure all the available IP addresses are attached before completion. - void OnComplete(ActiveResolveAttempts & activeAttempts); + /// Called IFF data is of SRV type and we are in SRV initialization state + /// + /// Initializes a resolver with the given SRV content as long as + /// inactive resolvers exist. + void ParseSRVResource(const ResourceData & data); -private: - OperationalResolveDelegate * mOperationalDelegate; - CommissioningResolveDelegate * mCommissioningDelegate; - DiscoveryType mDiscoveryType; - ResolvedNodeData mNodeData; - DiscoveredNodeData mDiscoveredNodeData; - chip::Inet::InterfaceId mInterfaceId; - BytesRange mPacketRange; + /// Called IFF parsing state is in RecordParsing + /// + /// Forwards the resource to all active resolvers. + void ParseResource(const ResourceData & data); - bool mValid = false; - bool mHasNodePort = false; - bool mHasIP = false; + enum class RecordParsingState + { + kIdle, + kSrvInitialization, + kRecordParsing, + }; - void OnCommissionableNodeSrvRecord(SerializedQNameIterator name, const SrvRecord & srv); - void OnOperationalSrvRecord(SerializedQNameIterator name, const SrvRecord & srv); + static constexpr size_t kMinMdnsNumParallelResolvers = CHIP_CONFIG_MINMDNS_MAX_PARALLEL_RESOLVES; - /// Handle processing of a newly received IP address - /// - /// Will place the given [addr] into the address list of [resolutionData] assuming that - /// there is enough space for that. - void OnNodeIPAddress(CommonResolutionData & resolutionData, const chip::Inet::IPAddress & addr); -}; + // Individual parse set + bool mIsResponse = false; + Inet::InterfaceId mInterfaceId = Inet::InterfaceId::Null(); + BytesRange mPacketRange; + RecordParsingState mParsingState = RecordParsingState::kIdle; -void PacketDataReporter::OnQuery(const QueryData & data) -{ - // Ignore queries: - // - unicast answers will include the corresponding query in the answer - // packet, however that is not interesting for the resolver. -} + // resolvers kept between parse steps + ActiveResolveAttempts & mActiveResolves; + IncrementalResolver mResolvers[kMinMdnsNumParallelResolvers]; +}; -void PacketDataReporter::OnHeader(ConstHeaderRef & header) +void PacketParser::OnHeader(ConstHeaderRef & header) { - mValid = header.GetFlags().IsResponse(); - mHasIP = false; // will need to get at least one valid IP eventually - mHasNodePort = false; // also need node-port which we do not have yet + mIsResponse = header.GetFlags().IsResponse(); +#ifdef MINMDNS_RESOLVER_OVERLY_VERBOSE if (header.GetFlags().IsTruncated()) { -#ifdef MINMDNS_RESOLVER_OVERLY_VERBOSE // MinMdns does not cache data, so receiving piecewise data does not work ChipLogError(Discovery, "Truncated responses not supported for address resolution"); -#endif } +#endif } -void PacketDataReporter::OnOperationalSrvRecord(SerializedQNameIterator name, const SrvRecord & srv) +void PacketParser::OnQuery(const QueryData & data) { - mdns::Minimal::SerializedQNameIterator it = srv.GetName(); - if (it.Next()) - { - Platform::CopyString(mNodeData.resolutionData.hostName, it.Value()); - } - - if (!name.Next()) - { -#ifdef MINMDNS_RESOLVER_OVERLY_VERBOSE - ChipLogError(Discovery, "mDNS packet is missing a valid server name"); -#endif - return; - } + // Ignore queries: + // - unicast answers will include the corresponding query in the answer + // packet, however that is not interesting for the resolver. +} - if (ExtractIdFromInstanceName(name.Value(), &mNodeData.operationalData.peerId) != CHIP_NO_ERROR) +void PacketParser::OnResource(ResourceType type, const ResourceData & data) +{ + if (!mIsResponse) { - ChipLogError(Discovery, "Failed to parse peer id from %s", name.Value()); return; } - mNodeData.resolutionData.port = srv.GetPort(); - mHasNodePort = true; -} - -void PacketDataReporter::OnCommissionableNodeSrvRecord(SerializedQNameIterator name, const SrvRecord & srv) -{ - // Host name is the first part of the qname - mdns::Minimal::SerializedQNameIterator it = srv.GetName(); - if (it.Next()) + switch (mParsingState) { - Platform::CopyString(mDiscoveredNodeData.resolutionData.hostName, it.Value()); + case RecordParsingState::kSrvInitialization: { + if (data.GetType() != QType::SRV) + { + return; + } + ParseSRVResource(data); + break; } - if (name.Next()) - { - strncpy(mDiscoveredNodeData.commissionData.instanceName, name.Value(), sizeof(CommissionNodeData::instanceName)); + case RecordParsingState::kRecordParsing: + ParseResource(data); + break; + case RecordParsingState::kIdle: + ChipLogError(Discovery, "Illegal state: received DNSSD resource while IDLE"); + break; } - mDiscoveredNodeData.resolutionData.port = srv.GetPort(); - mHasNodePort = true; } -void PacketDataReporter::OnNodeIPAddress(CommonResolutionData & resolutionData, const chip::Inet::IPAddress & addr) +void PacketParser::ParseResource(const ResourceData & data) { - // TODO: should validate that the IP address we receive belongs to the - // server associated with the SRV record. - // - // This code assumes that all entries in the mDNS packet relate to the - // same entity. This may not be correct if multiple servers are reported - // (if multi-admin decides to use unique ports for every ecosystem). - if (resolutionData.numIPs >= CommonResolutionData::kMaxIPAddresses) + for (auto & resolver : mResolvers) { - ChipLogDetail(Discovery, "Number of IP addresses overflow. Discarding extra addresses."); - return; + if (resolver.IsActive()) + { + CHIP_ERROR err = resolver.OnRecord(mInterfaceId, data, mPacketRange); + if (err != CHIP_NO_ERROR) + { + ChipLogError(Discovery, "DNSSD parse error: %" CHIP_ERROR_FORMAT, err.Format()); + } + } } - resolutionData.ipAddress[resolutionData.numIPs++] = addr; - resolutionData.interfaceId = mInterfaceId; - mHasIP = true; -} -bool HasQNamePart(SerializedQNameIterator qname, QNamePart part) -{ - while (qname.Next()) + // Once an IP address is received, stop requesting it. + if (data.GetType() == QType::AAAA) { - if (strcmp(qname.Value(), part) == 0) - { - return true; - } + mActiveResolves.CompleteIpResolution(data.GetName()); } - return false; } -void PacketDataReporter::OnResource(ResourceType type, const ResourceData & data) +void PacketParser::ParseSRVResource(const ResourceData & data) { - if (!mValid) + SrvRecord srv; + if (!srv.Parse(data.GetData(), mPacketRange)) { + ChipLogError(Discovery, "Packet data reporter failed to parse SRV record"); return; } - /// Data content is expected to contain: - /// - A SRV entry that includes the node ID in expected format (fabric + nodeid) - /// - Can extract: fabricid, nodeid, port - /// - References ServerName - /// - Additional records tied to ServerName contain A/AAAA records for IP address data - switch (data.GetType()) + for (auto & resolver : mResolvers) { - case QType::SRV: { - SrvRecord srv; - if (!srv.Parse(data.GetData(), mPacketRange)) + if (resolver.IsActive() && (resolver.GetRecordName() == data.GetName())) { - ChipLogError(Discovery, "Packet data reporter failed to parse SRV record"); - } - else if (mDiscoveryType == DiscoveryType::kOperational) - { - // Ensure this is our record. - // TODO: Fix this comparison which is too loose. - if (HasQNamePart(data.GetName(), kOperationalServiceName)) - { - OnOperationalSrvRecord(data.GetName(), srv); - } - else - { - ChipLogError(Discovery, "Invalid operational srv name: no '%s' part found.", kOperationalServiceName); - } - } - else if (mDiscoveryType == DiscoveryType::kCommissionableNode || mDiscoveryType == DiscoveryType::kCommissionerNode) - { - // TODO: Fix this comparison which is too loose. - if (HasQNamePart(data.GetName(), kCommissionableServiceName) || HasQNamePart(data.GetName(), kCommissionerServiceName)) - { - OnCommissionableNodeSrvRecord(data.GetName(), srv); - } - else - { - ChipLogError(Discovery, "Invalid commision srv name: no '%s' or '%s' part found.", kCommissionableServiceName, - kCommissionerServiceName); - } - } - break; - } - case QType::PTR: { - if (mDiscoveryType == DiscoveryType::kCommissionableNode) - { - SerializedQNameIterator qname; - ParsePtrRecord(data.GetData(), mPacketRange, &qname); - if (qname.Next()) - { - strncpy(mDiscoveredNodeData.commissionData.instanceName, qname.Value(), sizeof(CommissionNodeData::instanceName)); - } + ChipLogDetail(Discovery, "SRV record already actively processed."); + return; } - break; } - case QType::TXT: - if (mDiscoveryType == DiscoveryType::kCommissionableNode || mDiscoveryType == DiscoveryType::kCommissionerNode) - { - TxtRecordDelegateImpl commonDelegate(mDiscoveredNodeData.resolutionData); - ParseTxtRecord(data.GetData(), &commonDelegate); - TxtRecordDelegateImpl commissionDelegate(mDiscoveredNodeData.commissionData); - ParseTxtRecord(data.GetData(), &commissionDelegate); - } - else if (mDiscoveryType == DiscoveryType::kOperational) - { - TxtRecordDelegateImpl commonDelegate(mNodeData.resolutionData); - ParseTxtRecord(data.GetData(), &commonDelegate); - } - break; - case QType::A: { - Inet::IPAddress addr; - if (!ParseARecord(data.GetData(), &addr)) - { - ChipLogError(Discovery, "Packet data reporter failed to parse A record"); - } - else - { - if (mDiscoveryType == DiscoveryType::kOperational) - { - OnNodeIPAddress(mNodeData.resolutionData, addr); - } - else if (mDiscoveryType == DiscoveryType::kCommissionableNode || mDiscoveryType == DiscoveryType::kCommissionerNode) - { - OnNodeIPAddress(mDiscoveredNodeData.resolutionData, addr); - } - } - break; - } - case QType::AAAA: { - Inet::IPAddress addr; - if (!ParseAAAARecord(data.GetData(), &addr)) + for (auto & resolver : mResolvers) + { + if (resolver.IsActive()) { - ChipLogError(Discovery, "Packet data reporter failed to parse AAAA record"); + continue; } - else + + CHIP_ERROR err = resolver.InitializeParsing(data.GetName(), srv); + if (err != CHIP_NO_ERROR) { - if (mDiscoveryType == DiscoveryType::kOperational) - { - OnNodeIPAddress(mNodeData.resolutionData, addr); - } - else if (mDiscoveryType == DiscoveryType::kCommissionableNode || mDiscoveryType == DiscoveryType::kCommissionerNode) - { - OnNodeIPAddress(mDiscoveredNodeData.resolutionData, addr); - } + // Receiving records that we do not need to parse is normal: + // MinMDNS may receive all DNSSD packets on the network, only + // interested in a subset that is matter-specific +#ifdef MINMDNS_RESOLVER_OVERLY_VERBOSE + ChipLogError(Discovery, "Could not start SRV record processing: %" CHIP_ERROR_FORMAT, err.Format()); +#endif } - break; - } - default: - break; + + // Done finding an inactive resolver and attempting to use it. + return; } + + ChipLogError(Discovery, "Insufficient parsers to process all SRV entries."); } -void PacketDataReporter::OnComplete(ActiveResolveAttempts & activeAttempts) +void PacketParser::ParseSrvRecords(const BytesRange & packet) { - if (mDiscoveryType == DiscoveryType::kCommissionableNode || mDiscoveryType == DiscoveryType::kCommissionerNode) - { - if (!mDiscoveredNodeData.resolutionData.IsValid()) - { - ChipLogError(Discovery, "Discovered node data is not valid. Commissioning discovery not complete."); - return; - } + mParsingState = RecordParsingState::kSrvInitialization; + mPacketRange = packet; - activeAttempts.Complete(mDiscoveredNodeData); - if (mCommissioningDelegate != nullptr) - { - mCommissioningDelegate->OnNodeDiscovered(mDiscoveredNodeData); - } - else - { - ChipLogError(Discovery, "No delegate to report commissioning node discovery"); - } - } - else if (mDiscoveryType == DiscoveryType::kOperational) + if (!ParsePacket(packet, this)) { - if (!mHasIP) - { - ChipLogError(Discovery, "Operational discovery has no valid ip address. Resolve not complete."); - return; - } + ChipLogError(Discovery, "DNSSD packet parsing failed (for SRV records)"); + } - if (!mHasNodePort) - { - ChipLogError(Discovery, "Operational discovery has no valid node/port. Resolve not complete."); - return; - } + mParsingState = RecordParsingState::kIdle; +} - activeAttempts.Complete(mNodeData.operationalData.peerId); - mNodeData.LogNodeIdResolved(); +void PacketParser::ParseNonSrvRecords(Inet::InterfaceId interface, const BytesRange & packet) +{ + mParsingState = RecordParsingState::kRecordParsing; + mPacketRange = packet; + mInterfaceId = interface; - if (mOperationalDelegate != nullptr) - { - mOperationalDelegate->OnOperationalNodeResolved(mNodeData); - } - else - { - ChipLogError(Discovery, "No delegate to report operational node discovery"); - } + if (!ParsePacket(packet, this)) + { + ChipLogError(Discovery, "DNSSD packet parsing failed (for non-srv records)"); } + + mParsingState = RecordParsingState::kIdle; } class MinMdnsResolver : public Resolver, public MdnsPacketDelegate { public: - MinMdnsResolver() : mActiveResolves(&chip::System::SystemClock()) + MinMdnsResolver() : mActiveResolves(&chip::System::SystemClock()), mPacketParser(mActiveResolves) { GlobalMinimalMdnsServer::Instance().SetResponseDelegate(this); } @@ -392,18 +266,29 @@ class MinMdnsResolver : public Resolver, public MdnsPacketDelegate private: OperationalResolveDelegate * mOperationalDelegate = nullptr; CommissioningResolveDelegate * mCommissioningDelegate = nullptr; - DiscoveryType mDiscoveryType = DiscoveryType::kUnknown; System::Layer * mSystemLayer = nullptr; ActiveResolveAttempts mActiveResolves; + PacketParser mPacketParser; + + void ScheduleIpAddressResolve(SerializedQNameIterator hostName); - CHIP_ERROR SendPendingResolveQueries(); - CHIP_ERROR SendPendingBrowseQueries(); CHIP_ERROR SendAllPendingQueries(); CHIP_ERROR ScheduleRetries(); + /// Prepare a query for the given schedule attempt + CHIP_ERROR BuildQuery(QueryBuilder & builder, const ActiveResolveAttempts::ScheduledAttempt & attempt); + + /// Prepare a query for specific resolve types + CHIP_ERROR BuildQuery(QueryBuilder & builder, const ActiveResolveAttempts::ScheduledAttempt::Browse & data, bool firstSend); + CHIP_ERROR BuildQuery(QueryBuilder & builder, const ActiveResolveAttempts::ScheduledAttempt::Resolve & data, bool firstSend); + CHIP_ERROR BuildQuery(QueryBuilder & builder, const ActiveResolveAttempts::ScheduledAttempt::IpResolve & data, bool firstSend); + + /// Clear any incremental resolver that is not waiting for a AAAA address. + void ExpireIncrementalResolvers(); + void AdvancePendingResolverStates(); + static void RetryCallback(System::Layer *, void * self); - CHIP_ERROR SendQuery(mdns::Minimal::FullQName qname, mdns::Minimal::QType type, bool unicastResponse); CHIP_ERROR BrowseNodes(DiscoveryType type, DiscoveryFilter subtype); template mdns::Minimal::FullQName CheckAndAllocateQName(Args &&... parts) @@ -419,26 +304,102 @@ class MinMdnsResolver : public Resolver, public MdnsPacketDelegate char qnameStorage[kMaxQnameSize]; }; -void MinMdnsResolver::OnMdnsPacketData(const BytesRange & data, const chip::Inet::IPPacketInfo * info) +void MinMdnsResolver::ScheduleIpAddressResolve(SerializedQNameIterator hostName) { - if ((mOperationalDelegate == nullptr) && (mCommissioningDelegate == nullptr)) + HeapQName target(hostName); + if (!target.IsOk()) { + ChipLogError(Discovery, "Memory allocation error for IP address resolution"); return; } + mActiveResolves.MarkPending(ActiveResolveAttempts::ScheduledAttempt::IpResolve(std::move(target))); +} - PacketDataReporter reporter(mOperationalDelegate, mCommissioningDelegate, info->Interface, mDiscoveryType, data); - - if (!ParsePacket(data, &reporter)) - { - ChipLogError(Discovery, "Failed to parse received mDNS packet"); - } - else +void MinMdnsResolver::AdvancePendingResolverStates() +{ + for (IncrementalResolver * resolver = mPacketParser.ResolverBegin(); resolver != mPacketParser.ResolverEnd(); resolver++) { - reporter.OnComplete(mActiveResolves); - ScheduleRetries(); + if (!resolver->IsActive()) + { + continue; + } + + IncrementalResolver::RequiredInformationFlags missing = resolver->GetMissingRequiredInformation(); + + if (missing.Has(IncrementalResolver::RequiredInformationBitFlags::kIpAddress)) + { + ScheduleIpAddressResolve(resolver->GetTargetHostName()); + continue; + } + + if (missing.HasAny()) + { + // Expect either IP missing (ask for it) or done. Anything else is not handled + ChipLogError(Discovery, "Unexpected state: cannot advance resolver with missing information"); + resolver->ResetToInactive(); + continue; + } + + // SUCCESS. Call the delegates + if (resolver->IsActiveCommissionParse()) + { + DiscoveredNodeData nodeData; + + CHIP_ERROR err = resolver->Take(nodeData); + if (err != CHIP_NO_ERROR) + { + ChipLogError(Discovery, "Failed to take discovery result: %" CHIP_ERROR_FORMAT, err.Format()); + } + + mActiveResolves.Complete(nodeData); + if (mCommissioningDelegate != nullptr) + { + mCommissioningDelegate->OnNodeDiscovered(nodeData); + } + else + { + ChipLogError(Discovery, "No delegate to report commissioning node discovery"); + } + } + else if (resolver->IsActiveOperationalParse()) + { + ResolvedNodeData nodeData; + + CHIP_ERROR err = resolver->Take(nodeData); + if (err != CHIP_NO_ERROR) + { + ChipLogError(Discovery, "Failed to take discovery result: %" CHIP_ERROR_FORMAT, err.Format()); + } + + mActiveResolves.Complete(nodeData.operationalData.peerId); + if (mOperationalDelegate != nullptr) + { + mOperationalDelegate->OnOperationalNodeResolved(nodeData); + } + else + { + ChipLogError(Discovery, "No delegate to report operational node discovery"); + } + } + else + { + ChipLogError(Discovery, "Unexpected state: record type unknown"); + resolver->ResetToInactive(); + } } } +void MinMdnsResolver::OnMdnsPacketData(const BytesRange & data, const chip::Inet::IPPacketInfo * info) +{ + // Fill up any relevant data + mPacketParser.ParseSrvRecords(data); + mPacketParser.ParseNonSrvRecords(info->Interface, data); + + AdvancePendingResolverStates(); + + ScheduleRetries(); +} + CHIP_ERROR MinMdnsResolver::Init(chip::Inet::EndPointManager * udpEndPointManager) { /// Note: we do not double-check the port as we assume the APP will always use @@ -458,139 +419,205 @@ void MinMdnsResolver::Shutdown() GlobalMinimalMdnsServer::Instance().ShutdownServer(); } -CHIP_ERROR MinMdnsResolver::SendQuery(mdns::Minimal::FullQName qname, mdns::Minimal::QType type, bool unicastResponse) +CHIP_ERROR MinMdnsResolver::BuildQuery(QueryBuilder & builder, const ActiveResolveAttempts::ScheduledAttempt::Browse & data, + bool firstSend) { - System::PacketBufferHandle buffer = System::PacketBufferHandle::New(kMdnsMaxPacketSize); - ReturnErrorCodeIf(buffer.IsNull(), CHIP_ERROR_NO_MEMORY); + mdns::Minimal::FullQName qname; + + switch (data.type) + { + case DiscoveryType::kOperational: + qname = CheckAndAllocateQName(kOperationalServiceName, kOperationalProtocol, kLocalDomain); + break; + case DiscoveryType::kCommissionableNode: + if (data.filter.type == DiscoveryFilterType::kNone) + { + qname = CheckAndAllocateQName(kCommissionableServiceName, kCommissionProtocol, kLocalDomain); + } + else if (data.filter.type == DiscoveryFilterType::kInstanceName) + { + qname = CheckAndAllocateQName(data.filter.instanceName, kCommissionableServiceName, kCommissionProtocol, kLocalDomain); + } + else + { + char subtypeStr[Common::kSubTypeMaxLength + 1]; + ReturnErrorOnFailure(MakeServiceSubtype(subtypeStr, sizeof(subtypeStr), data.filter)); + qname = CheckAndAllocateQName(subtypeStr, kSubtypeServiceNamePart, kCommissionableServiceName, kCommissionProtocol, + kLocalDomain); + } + break; + case DiscoveryType::kCommissionerNode: + if (data.filter.type == DiscoveryFilterType::kNone) + { + qname = CheckAndAllocateQName(kCommissionerServiceName, kCommissionProtocol, kLocalDomain); + } + else + { + char subtypeStr[Common::kSubTypeMaxLength + 1]; + ReturnErrorOnFailure(MakeServiceSubtype(subtypeStr, sizeof(subtypeStr), data.filter)); + qname = CheckAndAllocateQName(subtypeStr, kSubtypeServiceNamePart, kCommissionerServiceName, kCommissionProtocol, + kLocalDomain); + } + break; + case DiscoveryType::kUnknown: + break; + } - QueryBuilder builder(std::move(buffer)); - builder.Header().SetMessageId(0); + ReturnErrorCodeIf(!qname.nameCount, CHIP_ERROR_NO_MEMORY); mdns::Minimal::Query query(qname); - query.SetType(type).SetClass(mdns::Minimal::QClass::IN); - query.SetAnswerViaUnicast(unicastResponse); - + query + .SetClass(QClass::IN) // + .SetType(QType::ANY) // + .SetAnswerViaUnicast(firstSend) // + ; builder.AddQuery(query); - ReturnErrorCodeIf(!builder.Ok(), CHIP_ERROR_INTERNAL); - - if (unicastResponse) - { - ReturnErrorOnFailure(GlobalMinimalMdnsServer::Server().BroadcastUnicastQuery(builder.ReleasePacket(), kMdnsPort)); - } - else - { - ReturnErrorOnFailure(GlobalMinimalMdnsServer::Server().BroadcastSend(builder.ReleasePacket(), kMdnsPort)); - } return CHIP_NO_ERROR; } -CHIP_ERROR MinMdnsResolver::SendAllPendingQueries() +CHIP_ERROR MinMdnsResolver::BuildQuery(QueryBuilder & builder, const ActiveResolveAttempts::ScheduledAttempt::Resolve & data, + bool firstSend) { - CHIP_ERROR browseErr = SendPendingBrowseQueries(); - CHIP_ERROR resolveErr = SendPendingResolveQueries(); - return resolveErr == CHIP_NO_ERROR ? browseErr : resolveErr; -} + char nameBuffer[kMaxOperationalServiceNameSize] = ""; -CHIP_ERROR MinMdnsResolver::FindCommissionableNodes(DiscoveryFilter filter) -{ - return BrowseNodes(DiscoveryType::kCommissionableNode, filter); + // Node and fabricid are encoded in server names. + ReturnErrorOnFailure(MakeInstanceName(nameBuffer, sizeof(nameBuffer), data.peerId)); + + const char * instanceQName[] = { nameBuffer, kOperationalServiceName, kOperationalProtocol, kLocalDomain }; + Query query(instanceQName); + + query + .SetClass(QClass::IN) // + .SetType(QType::ANY) // + .SetAnswerViaUnicast(firstSend) // + ; + + builder.AddQuery(query); + + return CHIP_NO_ERROR; } -CHIP_ERROR MinMdnsResolver::FindCommissioners(DiscoveryFilter filter) +CHIP_ERROR MinMdnsResolver::BuildQuery(QueryBuilder & builder, const ActiveResolveAttempts::ScheduledAttempt::IpResolve & data, + bool firstSend) { - return BrowseNodes(DiscoveryType::kCommissionerNode, filter); + + Query query(data.hostName.Content()); + + query + .SetClass(QClass::IN) // + .SetType(QType::AAAA) // + .SetAnswerViaUnicast(firstSend) // + ; + + builder.AddQuery(query); + + return CHIP_NO_ERROR; } -CHIP_ERROR MinMdnsResolver::BrowseNodes(DiscoveryType type, DiscoveryFilter filter) +CHIP_ERROR MinMdnsResolver::BuildQuery(QueryBuilder & builder, const ActiveResolveAttempts::ScheduledAttempt & attempt) { - mDiscoveryType = type; - mActiveResolves.MarkPending(filter, type); + if (attempt.IsResolve()) + { + ReturnErrorOnFailure(BuildQuery(builder, attempt.ResolveData(), attempt.firstSend)); + } + else if (attempt.IsBrowse()) + { + ReturnErrorOnFailure(BuildQuery(builder, attempt.BrowseData(), attempt.firstSend)); + } + else if (attempt.IsIpResolve()) + { + ReturnErrorOnFailure(BuildQuery(builder, attempt.IpResolveData(), attempt.firstSend)); + } + else + { + return CHIP_ERROR_INVALID_ARGUMENT; + } - return SendPendingBrowseQueries(); + ReturnErrorCodeIf(!builder.Ok(), CHIP_ERROR_INTERNAL); + return CHIP_NO_ERROR; } -CHIP_ERROR MinMdnsResolver::SendPendingBrowseQueries() +CHIP_ERROR MinMdnsResolver::SendAllPendingQueries() { - CHIP_ERROR returnErr = CHIP_NO_ERROR; while (true) { - Optional attempt = mActiveResolves.NextScheduled(); + Optional resolve = mActiveResolves.NextScheduled(); - if (!attempt.HasValue()) + if (!resolve.HasValue()) { break; } - if (!attempt.Value().IsBrowse()) - { - continue; - } - mdns::Minimal::FullQName qname; + System::PacketBufferHandle buffer = System::PacketBufferHandle::New(kMdnsMaxPacketSize); + ReturnErrorCodeIf(buffer.IsNull(), CHIP_ERROR_NO_MEMORY); + + QueryBuilder builder(std::move(buffer)); + builder.Header().SetMessageId(0); + + ReturnErrorOnFailure(BuildQuery(builder, resolve.Value())); - switch (attempt.Value().BrowseData().type) + if (resolve.Value().firstSend) { - case DiscoveryType::kOperational: - qname = CheckAndAllocateQName(kOperationalServiceName, kOperationalProtocol, kLocalDomain); - break; - case DiscoveryType::kCommissionableNode: - if (attempt.Value().BrowseData().filter.type == DiscoveryFilterType::kNone) - { - qname = CheckAndAllocateQName(kCommissionableServiceName, kCommissionProtocol, kLocalDomain); - } - else if (attempt.Value().BrowseData().filter.type == DiscoveryFilterType::kInstanceName) - { - qname = CheckAndAllocateQName(attempt.Value().BrowseData().filter.instanceName, kCommissionableServiceName, - kCommissionProtocol, kLocalDomain); - } - else - { - char subtypeStr[Common::kSubTypeMaxLength + 1]; - ReturnErrorOnFailure(MakeServiceSubtype(subtypeStr, sizeof(subtypeStr), attempt.Value().BrowseData().filter)); - qname = CheckAndAllocateQName(subtypeStr, kSubtypeServiceNamePart, kCommissionableServiceName, kCommissionProtocol, - kLocalDomain); - } - break; - case DiscoveryType::kCommissionerNode: - if (attempt.Value().BrowseData().filter.type == DiscoveryFilterType::kNone) - { - qname = CheckAndAllocateQName(kCommissionerServiceName, kCommissionProtocol, kLocalDomain); - } - else - { - char subtypeStr[Common::kSubTypeMaxLength + 1]; - ReturnErrorOnFailure(MakeServiceSubtype(subtypeStr, sizeof(subtypeStr), attempt.Value().BrowseData().filter)); - qname = CheckAndAllocateQName(subtypeStr, kSubtypeServiceNamePart, kCommissionerServiceName, kCommissionProtocol, - kLocalDomain); - } - break; - case DiscoveryType::kUnknown: - break; + ReturnErrorOnFailure(GlobalMinimalMdnsServer::Server().BroadcastUnicastQuery(builder.ReleasePacket(), kMdnsPort)); } - if (!qname.nameCount) + else { - return CHIP_ERROR_NO_MEMORY; + ReturnErrorOnFailure(GlobalMinimalMdnsServer::Server().BroadcastSend(builder.ReleasePacket(), kMdnsPort)); } + } - bool unicastResponse = attempt.Value().firstSend; + ExpireIncrementalResolvers(); - CHIP_ERROR err = SendQuery(qname, mdns::Minimal::QType::ANY, unicastResponse); - if (err != CHIP_NO_ERROR) + return ScheduleRetries(); +} + +void MinMdnsResolver::ExpireIncrementalResolvers() +{ + // once all queries are sent, if any SRV cannot receive AAAA addresses, expire it + for (IncrementalResolver * resolver = mPacketParser.ResolverBegin(); resolver != mPacketParser.ResolverEnd(); resolver++) + { + if (!resolver->IsActive()) { - // We want to continue sending, but we do want this error returned - returnErr = err; + continue; + } + + IncrementalResolver::RequiredInformationFlags missing = resolver->GetMissingRequiredInformation(); + if (missing.Has(IncrementalResolver::RequiredInformationBitFlags::kIpAddress)) + { + if (mActiveResolves.IsWaitingForIpResolutionFor(resolver->GetTargetHostName())) + { + continue; + } } + + // mark as expired: not waiting for anything + resolver->ResetToInactive(); } - ReturnErrorOnFailure(ScheduleRetries()); - return returnErr; +} + +CHIP_ERROR MinMdnsResolver::FindCommissionableNodes(DiscoveryFilter filter) +{ + return BrowseNodes(DiscoveryType::kCommissionableNode, filter); +} + +CHIP_ERROR MinMdnsResolver::FindCommissioners(DiscoveryFilter filter) +{ + return BrowseNodes(DiscoveryType::kCommissionerNode, filter); +} + +CHIP_ERROR MinMdnsResolver::BrowseNodes(DiscoveryType type, DiscoveryFilter filter) +{ + mActiveResolves.MarkPending(filter, type); + + return SendAllPendingQueries(); } CHIP_ERROR MinMdnsResolver::ResolveNodeId(const PeerId & peerId, Inet::IPAddressType type) { - mDiscoveryType = DiscoveryType::kOperational; mActiveResolves.MarkPending(peerId); - return SendPendingResolveQueries(); + return SendAllPendingQueries(); } CHIP_ERROR MinMdnsResolver::ScheduleRetries() @@ -613,73 +640,6 @@ void MinMdnsResolver::RetryCallback(System::Layer *, void * self) reinterpret_cast(self)->SendAllPendingQueries(); } -CHIP_ERROR MinMdnsResolver::SendPendingResolveQueries() -{ - while (true) - { - Optional resolve = mActiveResolves.NextScheduled(); - - if (!resolve.HasValue()) - { - break; - } - if (!resolve.Value().IsResolve()) - { - continue; - } - - System::PacketBufferHandle buffer = System::PacketBufferHandle::New(kMdnsMaxPacketSize); - ReturnErrorCodeIf(buffer.IsNull(), CHIP_ERROR_NO_MEMORY); - - QueryBuilder builder(std::move(buffer)); - builder.Header().SetMessageId(0); - - { - char nameBuffer[kMaxOperationalServiceNameSize] = ""; - - // Node and fabricid are encoded in server names. - ReturnErrorOnFailure(MakeInstanceName(nameBuffer, sizeof(nameBuffer), resolve.Value().ResolveData().peerId)); - - const char * instanceQName[] = { nameBuffer, kOperationalServiceName, kOperationalProtocol, kLocalDomain }; - Query query(instanceQName); - - query - .SetClass(QClass::IN) // - .SetType(QType::ANY) // - .SetAnswerViaUnicast(resolve.Value().firstSend) // - ; - - // NOTE: type above is NOT A or AAAA because the name searched for is - // a SRV record. The layout is: - // SRV -> hostname - // Hostname -> A - // Hostname -> AAAA - // - // Query is sent for ANY and expectation is to receive A/AAAA records - // in the additional section of the reply. - // - // Sending a A/AAAA query will return no results - // Sending a SRV query will return the srv only and an additional query - // would be needed to resolve the host name to an IP address - - builder.AddQuery(query); - } - - ReturnErrorCodeIf(!builder.Ok(), CHIP_ERROR_INTERNAL); - - if (resolve.Value().firstSend) - { - ReturnErrorOnFailure(GlobalMinimalMdnsServer::Server().BroadcastUnicastQuery(builder.ReleasePacket(), kMdnsPort)); - } - else - { - ReturnErrorOnFailure(GlobalMinimalMdnsServer::Server().BroadcastSend(builder.ReleasePacket(), kMdnsPort)); - } - } - - return ScheduleRetries(); -} - MinMdnsResolver gResolver; } // namespace @@ -689,6 +649,14 @@ Resolver & chip::Dnssd::Resolver::Instance() return gResolver; } +ResolverProxy::~ResolverProxy() +{ + // TODO: this is a hack: resolver proxies used for commissionable discovery + // and they don't interact well with each other. + gResolver.SetCommissioningDelegate(nullptr); + Shutdown(); +} + // Minimal implementation does not support associating a context to a request (while platforms implementations do). So keep // updating the delegate that ends up being used by the server by calling 'SetOperationalDelegate'. // This effectively allow minimal to have multiple controllers issuing requests as long the requests are serialized, but diff --git a/src/lib/dnssd/Resolver_ImplNone.cpp b/src/lib/dnssd/Resolver_ImplNone.cpp index bf4de55100bf45..f1315be776e76a 100644 --- a/src/lib/dnssd/Resolver_ImplNone.cpp +++ b/src/lib/dnssd/Resolver_ImplNone.cpp @@ -50,6 +50,11 @@ Resolver & chip::Dnssd::Resolver::Instance() return gResolver; } +ResolverProxy::~ResolverProxy() +{ + Shutdown(); +} + CHIP_ERROR ResolverProxy::ResolveNodeId(const PeerId & peerId, Inet::IPAddressType type) { return CHIP_ERROR_NOT_IMPLEMENTED; diff --git a/src/lib/dnssd/minimal_mdns/core/HeapQName.h b/src/lib/dnssd/minimal_mdns/core/HeapQName.h new file mode 100644 index 00000000000000..eb9662d7fea274 --- /dev/null +++ b/src/lib/dnssd/minimal_mdns/core/HeapQName.h @@ -0,0 +1,166 @@ +/* + * + * Copyright (c) 2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include +#include +#include + +namespace mdns { +namespace Minimal { + +/// Contructs a FullQName from a SerializedNameIterator +/// +/// Generally a conversion from an iterator to a `const char *[]` +/// using heap for underlying storage of the data. +class HeapQName +{ +public: + HeapQName() {} + HeapQName(SerializedQNameIterator name) + { + // Storage is: + // - separate pointers into mElementPointers + // - allocated pointers inside that + mElementCount = 0; + + SerializedQNameIterator it = name; + while (it.Next()) + { + // Count all elements + mElementCount++; + } + + if (!it.IsValid()) + { + return; + } + + mElementPointers.Alloc(mElementCount); + if (!mElementPointers) + { + return; + } + // ensure all set to null since we may need to free + for (size_t i = 0; i < mElementCount; i++) + { + mElementPointers[i] = nullptr; + } + + it = name; + size_t idx = 0; + while (it.Next()) + { + mElementPointers[idx] = chip::Platform::MemoryAllocString(it.Value(), strlen(it.Value())); + if (!mElementPointers[idx]) + { + return; + } + idx++; + } + mIsOk = true; + } + + HeapQName(const HeapQName & other) { *this = other; } + + HeapQName & operator=(const HeapQName & other) + { + Free(); + + if (!other) + { + return *this; // No point in copying the other value + } + + mElementCount = other.mElementCount; + mElementPointers.Alloc(other.mElementCount); + if (!mElementPointers) + { + return *this; + } + + for (size_t i = 0; i < mElementCount; i++) + { + mElementPointers[i] = nullptr; + } + + for (size_t i = 0; i < mElementCount; i++) + { + const char * other_data = other.mElementPointers[i]; + mElementPointers[i] = chip::Platform::MemoryAllocString(other_data, strlen(other_data)); + if (!mElementPointers[i]) + { + return *this; + } + } + mIsOk = true; + return *this; + } + + ~HeapQName() { Free(); } + + bool IsOk() const { return mIsOk; } + + operator bool() const { return IsOk(); } + bool operator!() const { return !IsOk(); } + + /// Returns the contained FullQName. + /// + /// VALIDITY: since this references data inside `this` it is only valid + /// as long as `this` is valid. + FullQName Content() const + { + FullQName result; + + result.names = mElementPointers.Get(); + result.nameCount = mElementCount; + + return result; + } + +private: + void Free() + { + if (!mElementPointers) + { + return; + } + + for (size_t i = 0; i < mElementCount; i++) + { + if (mElementPointers[i] != nullptr) + { + chip::Platform::MemoryFree(mElementPointers[i]); + mElementPointers[i] = nullptr; + } + } + mElementPointers.Free(); + mElementCount = 0; + mIsOk = false; + } + + bool mIsOk = false; + size_t mElementCount = 0; + chip::Platform::ScopedMemoryBuffer mElementPointers; +}; + +} // namespace Minimal +} // namespace mdns diff --git a/src/lib/dnssd/minimal_mdns/core/tests/BUILD.gn b/src/lib/dnssd/minimal_mdns/core/tests/BUILD.gn index c1f4d872dce74e..098d582e60408f 100644 --- a/src/lib/dnssd/minimal_mdns/core/tests/BUILD.gn +++ b/src/lib/dnssd/minimal_mdns/core/tests/BUILD.gn @@ -18,7 +18,7 @@ import("//build_overrides/nlunit_test.gni") import("${chip_root}/build/chip/chip_test_suite.gni") -static_library("support") { +source_set("support") { sources = [ "QNameStrings.h" ] public_deps = [ @@ -32,6 +32,7 @@ chip_test_suite("tests") { test_sources = [ "TestFlatAllocatedQName.cpp", + "TestHeapQName.cpp", "TestQName.cpp", "TestRecordWriter.cpp", ] @@ -39,6 +40,7 @@ chip_test_suite("tests") { cflags = [ "-Wconversion" ] public_deps = [ + ":support", "${chip_root}/src/lib/core", "${chip_root}/src/lib/dnssd/minimal_mdns/core", "${nlunit_test_root}:nlunit-test", diff --git a/src/lib/dnssd/minimal_mdns/core/tests/TestHeapQName.cpp b/src/lib/dnssd/minimal_mdns/core/tests/TestHeapQName.cpp new file mode 100644 index 00000000000000..f77eb2441c948e --- /dev/null +++ b/src/lib/dnssd/minimal_mdns/core/tests/TestHeapQName.cpp @@ -0,0 +1,110 @@ +/* + * + * Copyright (c) 2022 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include + +namespace { + +using namespace mdns::Minimal; + +void Construction(nlTestSuite * inSuite, void * inContext) +{ + { + + const testing::TestQName<2> kShort({ "some", "test" }); + + HeapQName heapQName(kShort.Serialized()); + + NL_TEST_ASSERT(inSuite, heapQName.IsOk()); + NL_TEST_ASSERT(inSuite, heapQName.Content() == kShort.Full()); + NL_TEST_ASSERT(inSuite, kShort.Serialized() == heapQName.Content()); + } + + { + + const testing::TestQName<5> kLonger({ "these", "are", "more", "elements", "here" }); + + HeapQName heapQName(kLonger.Serialized()); + + NL_TEST_ASSERT(inSuite, heapQName.IsOk()); + NL_TEST_ASSERT(inSuite, heapQName.Content() == kLonger.Full()); + NL_TEST_ASSERT(inSuite, kLonger.Serialized() == heapQName.Content()); + } +} + +void Copying(nlTestSuite * inSuite, void * inContext) +{ + const testing::TestQName<2> kShort({ "some", "test" }); + + HeapQName name1(kShort.Serialized()); + HeapQName name2(name1); + HeapQName name3; + + name3 = name2; + + NL_TEST_ASSERT(inSuite, name1.IsOk()); + NL_TEST_ASSERT(inSuite, name2.IsOk()); + NL_TEST_ASSERT(inSuite, name3.IsOk()); + NL_TEST_ASSERT(inSuite, name1.Content() == name2.Content()); + NL_TEST_ASSERT(inSuite, name1.Content() == name3.Content()); +} + +static const nlTest sTests[] = { // + NL_TEST_DEF("Construction", Construction), // + NL_TEST_DEF("Copying", Copying), // + NL_TEST_SENTINEL() +}; + +int Setup(void * inContext) +{ + CHIP_ERROR error = chip::Platform::MemoryInit(); + if (error != CHIP_NO_ERROR) + return FAILURE; + return SUCCESS; +} + +/** + * Tear down the test suite. + */ +int Teardown(void * inContext) +{ + chip::Platform::MemoryShutdown(); + return SUCCESS; +} + +} // namespace + +int TestHeapQName(void) +{ + nlTestSuite theSuite = { + "HeapQName", + &sTests[0], + &Setup, + &Teardown, + }; + + nlTestRunner(&theSuite, nullptr); + + return (nlTestRunnerStats(&theSuite)); +} + +CHIP_REGISTER_TEST_SUITE(TestHeapQName) diff --git a/src/lib/dnssd/platform/tests/TestPlatform.cpp b/src/lib/dnssd/platform/tests/TestPlatform.cpp index e8498ce2947b8a..d7125f05b8fc41 100644 --- a/src/lib/dnssd/platform/tests/TestPlatform.cpp +++ b/src/lib/dnssd/platform/tests/TestPlatform.cpp @@ -226,6 +226,7 @@ int TestSetup(void * inContext) int TestTeardown(void * inContext) { + DiscoveryImplPlatform::GetInstance().Shutdown(); chip::Platform::MemoryShutdown(); return SUCCESS; } diff --git a/src/lib/dnssd/tests/TestIncrementalResolve.cpp b/src/lib/dnssd/tests/TestIncrementalResolve.cpp index eaa855e119bf1e..005779b4f892ab 100644 --- a/src/lib/dnssd/tests/TestIncrementalResolve.cpp +++ b/src/lib/dnssd/tests/TestIncrementalResolve.cpp @@ -100,7 +100,7 @@ void CallOnRecord(nlTestSuite * inSuite, IncrementalResolver & resolver, const R BytesRange packet(dataBuffer, dataBuffer + sizeof(dataBuffer)); const uint8_t * _ptr = dataBuffer; NL_TEST_ASSERT(inSuite, resource.Parse(packet, &_ptr)); - NL_TEST_ASSERT(inSuite, resolver.OnRecord(resource, packet) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, resolver.OnRecord(chip::Inet::InterfaceId::Null(), resource, packet) == CHIP_NO_ERROR); } void TestStoredServerName(nlTestSuite * inSuite, void * inContext) diff --git a/src/lib/support/ScopedBuffer.h b/src/lib/support/ScopedBuffer.h index 137aafe6c57ac8..ead4c563946e1d 100644 --- a/src/lib/support/ScopedBuffer.h +++ b/src/lib/support/ScopedBuffer.h @@ -142,7 +142,7 @@ class ScopedMemoryBuffer : public Impl::ScopedMemoryBufferBase inline T * Get() { return static_cast(Base::Ptr()); } inline T & operator[](size_t index) { return Get()[index]; } - inline const T * Get() const { return static_cast(Base::Ptr()); } + inline const T * Get() const { return static_cast(Base::Ptr()); } inline const T & operator[](size_t index) const { return Get()[index]; } inline T * Release() { return static_cast(Base::Release()); }