diff --git a/src/vswhere.lib/InstanceSelector.cpp b/src/vswhere.lib/InstanceSelector.cpp index 9c7db44..7437206 100644 --- a/src/vswhere.lib/InstanceSelector.cpp +++ b/src/vswhere.lib/InstanceSelector.cpp @@ -89,29 +89,30 @@ bool InstanceSelector::Less(const ISetupInstancePtr& a, const ISetupInstancePtr& } } -vector InstanceSelector::Select(_In_ IEnumSetupInstances* pEnum) const +vector InstanceSelector::Select(_In_opt_ IEnumSetupInstances* pEnum) const { - _ASSERT(pEnum); - - HRESULT hr = S_OK; - unsigned long celtFetched = 0; - ISetupInstance* pInstances[1] = {}; - vector instances; - do + if (pEnum) { - celtFetched = 0; + HRESULT hr = S_OK; + unsigned long celtFetched = 0; + ISetupInstance* pInstances[1] = {}; - hr = pEnum->Next(1, pInstances, &celtFetched); - if (SUCCEEDED(hr) && celtFetched) + do { - ISetupInstancePtr instance(pInstances[0], false); - if (IsMatch(instance)) + celtFetched = 0; + + hr = pEnum->Next(1, pInstances, &celtFetched); + if (SUCCEEDED(hr) && celtFetched) { - instances.push_back(instance); + ISetupInstancePtr instance(pInstances[0], false); + if (IsMatch(instance)) + { + instances.push_back(instance); + } } - } - } while (SUCCEEDED(hr) && celtFetched); + } while (SUCCEEDED(hr) && celtFetched); + } if (m_args.get_Latest() && 1 < instances.size()) { diff --git a/src/vswhere.lib/InstanceSelector.h b/src/vswhere.lib/InstanceSelector.h index 1ab701b..b3d601c 100644 --- a/src/vswhere.lib/InstanceSelector.h +++ b/src/vswhere.lib/InstanceSelector.h @@ -18,7 +18,7 @@ class InstanceSelector } bool Less(const ISetupInstancePtr& a, const ISetupInstancePtr& b) const; - std::vector Select(_In_ IEnumSetupInstances* pEnum) const; + std::vector Select(_In_opt_ IEnumSetupInstances* pEnum) const; private: static std::wstring GetId(_In_ ISetupPackageReference* pPackageReference); diff --git a/src/vswhere/Program.cpp b/src/vswhere/Program.cpp index ca7fc3c..9101967 100644 --- a/src/vswhere/Program.cpp +++ b/src/vswhere/Program.cpp @@ -7,6 +7,7 @@ using namespace std; +void GetEnumerator(_In_ const CommandArgs& args, _In_ ISetupConfigurationPtr& query, _In_ IEnumSetupInstancesPtr& e); void WriteLogo(_In_ const CommandArgs& args, _In_ wostream& out); int wmain(_In_ int argc, _In_ LPCWSTR argv[]) @@ -30,44 +31,14 @@ int wmain(_In_ int argc, _In_ LPCWSTR argv[]) ISetupConfigurationPtr query; IEnumSetupInstancesPtr e; - auto hr = query.CreateInstance(__uuidof(SetupConfiguration)); - if (FAILED(hr)) - { - if (REGDB_E_CLASSNOTREG == hr) - { - WriteLogo(args, out); - return ERROR_SUCCESS; - } - } - - // If all instances are requested, try to get the proper enumerator; otherwise, fall back to original enumerator. - if (args.get_All()) - { - ISetupConfiguration2Ptr query2; - - hr = query->QueryInterface(&query2); - if (SUCCEEDED(hr)) - { - hr = query2->EnumAllInstances(&e); - if (FAILED(hr)) - { - throw win32_error(hr); - } - } - } - - if (!e) - { - hr = query->EnumInstances(&e); - if (FAILED(hr)) - { - throw win32_error(hr); - } - } + GetEnumerator(args, query, e); // Attempt to get the ISetupHelper. ISetupHelperPtr helper; - query->QueryInterface(&helper); + if (query) + { + query->QueryInterface(&helper); + } InstanceSelector selector(args, helper); auto instances = selector.Select(e); @@ -99,6 +70,45 @@ int wmain(_In_ int argc, _In_ LPCWSTR argv[]) return E_FAIL; } +void GetEnumerator(_In_ const CommandArgs& args, _In_ ISetupConfigurationPtr& query, _In_ IEnumSetupInstancesPtr& e) +{ + auto hr = query.CreateInstance(__uuidof(SetupConfiguration)); + if (FAILED(hr)) + { + if (REGDB_E_CLASSNOTREG == hr) + { + return; + } + + throw win32_error(hr); + } + + // If all instances are requested, try to get the proper enumerator; otherwise, fall back to original enumerator. + if (args.get_All()) + { + ISetupConfiguration2Ptr query2; + + hr = query->QueryInterface(&query2); + if (SUCCEEDED(hr)) + { + hr = query2->EnumAllInstances(&e); + if (FAILED(hr)) + { + throw win32_error(hr); + } + } + } + + if (!e) + { + hr = query->EnumInstances(&e); + if (FAILED(hr)) + { + throw win32_error(hr); + } + } +} + void WriteLogo(_In_ const CommandArgs& args, _In_ wostream& out) { if (args.get_Logo()) diff --git a/test/vswhere.test/InstanceSelectorTests.cpp b/test/vswhere.test/InstanceSelectorTests.cpp index 6914dda..1bebe6c 100644 --- a/test/vswhere.test/InstanceSelectorTests.cpp +++ b/test/vswhere.test/InstanceSelectorTests.cpp @@ -11,6 +11,17 @@ using namespace Microsoft::VisualStudio::CppUnitTestFramework; TEST_CLASS(InstanceSelectorTests) { public: + TEST_METHOD(Select_Null) + { + CommandArgs args; + args.Parse(L"vswhere.exe"); + + InstanceSelector sut(args); + auto selected = sut.Select(NULL); + + Assert::AreEqual(0, selected.size()); + } + TEST_METHOD(Select_No_Product) { TestInstance instance =