Skip to content

Commit

Permalink
dpcpp does not instantiate Generic from GenericHelper
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Jun 12, 2023
1 parent eb3a609 commit 8e8e437
Show file tree
Hide file tree
Showing 19 changed files with 84 additions and 78 deletions.
3 changes: 1 addition & 2 deletions extensions/file_config/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ target_sources(file_config_custom


target_link_libraries(file_config_custom PUBLIC nlohmann_json::nlohmann_json Ginkgo::ginkgo)
# target_include_directories(file_config_custom PUBLIC "include")
target_include_directories(file_config_custom PUBLIC
$<BUILD_INTERFACE:${Ginkgo_SOURCE_DIR}/extensions/file_config/include>
$<INSTALL_INTERFACE:include/extensions/file_config>
Expand Down Expand Up @@ -82,4 +81,4 @@ install(EXPORT Ginkgo
DESTINATION "${GINKGO_INSTALL_CONFIG_DIR}")
if(GINKGO_BUILD_TESTS)
add_subdirectory(test)
endif()
endif()
2 changes: 1 addition & 1 deletion extensions/file_config/GkoFileConfigConfig.cmake.in
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

find_package(nlohmann_json 3.11.2 REQUIRED)

include(${CMAKE_CURRENT_LIST_DIR}/GinkgoTargets.cmake)
include(${CMAKE_CURRENT_LIST_DIR}/GkoFileConfig.cmake)
10 changes: 5 additions & 5 deletions extensions/file_config/base/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ struct Generic<gko::CudaExecutor> {
std::cout << "Cuda" << std::endl;
auto device_id = get_value_with_default(item, "device_id", 0);
auto ptr = CudaExecutor::create(device_id, ReferenceExecutor::create());
// add_logger(ptr, item, exec, linop, manager);
add_logger(ptr, item, exec, linop, manager);
return std::move(ptr);
}
};
Expand All @@ -82,7 +82,7 @@ struct Generic<gko::HipExecutor> {
std::cout << "Hip" << std::endl;
auto device_id = get_value_with_default(item, "device_id", 0);
auto ptr = HipExecutor::create(device_id, ReferenceExecutor::create());
// add_logger(ptr, item, exec, linop, manager);
add_logger(ptr, item, exec, linop, manager);
return std::move(ptr);
}
};
Expand All @@ -105,7 +105,7 @@ struct Generic<gko::DpcppExecutor> {
get_value_with_default(item, "device_type", std::string("all"));
auto ptr = DpcppExecutor::create(device_id, ReferenceExecutor::create(),
device_type);
// add_logger(ptr, item, exec, linop, manager);
add_logger(ptr, item, exec, linop, manager);
return std::move(ptr);
}
};
Expand All @@ -124,7 +124,7 @@ struct Generic<gko::ReferenceExecutor> {
{
std::cout << "Reference" << std::endl;
auto ptr = ReferenceExecutor::create();
// add_logger(ptr, item, exec, linop, manager);
add_logger(ptr, item, exec, linop, manager);
return std::move(ptr);
}
};
Expand All @@ -143,7 +143,7 @@ struct Generic<gko::OmpExecutor> {
{
std::cout << "Omp" << std::endl;
auto ptr = OmpExecutor::create();
// add_logger(ptr, item, exec, linop, manager);
add_logger(ptr, item, exec, linop, manager);
return std::move(ptr);
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,8 @@ std::shared_ptr<T> create_from_config(

/**
* create_from_config is another overloading to implement the function after
* selection on enum map.
* selection on enum map. This is the major implementation to select different
* template type from base class.
*
* @tparam T the enum type
* @tparam base the enum item
Expand Down
30 changes: 25 additions & 5 deletions extensions/file_config/include/file_config/base/helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,33 @@ namespace file_config {
* contains name.
*/
template <typename T>
std::shared_ptr<T> call(const nlohmann::json& item,
std::shared_ptr<const Executor> exec,
std::shared_ptr<const LinOp> linop,
ResourceManager* manager)
inline typename std::enable_if<!is_on_linopfactory<T>::value &&
!is_on_criterionfactory<T>::value,
std::shared_ptr<T>>::type
call(const nlohmann::json& item, std::shared_ptr<const Executor> exec,
std::shared_ptr<const LinOp> linop, ResourceManager* manager)
{
if (manager == nullptr) {
return GenericHelper<T>::build(item, exec, linop, manager);
return Generic<T>::build(item, exec, linop, manager);
} else {
std::cout << exec.get() << std::endl;
return manager->build_item<T>(item, exec, linop);
}
}

// In dpcpp, GenericHelper static function does not instantiate the
// corresponding Generic function. We use inline function to instantiate all
// possible template
template <typename T>
inline typename std::enable_if<is_on_linopfactory<T>::value ||
is_on_criterionfactory<T>::value,
std::shared_ptr<T>>::type
call(const nlohmann::json& item, std::shared_ptr<const Executor> exec,
std::shared_ptr<const LinOp> linop, ResourceManager* manager)
{
if (manager == nullptr) {
return Generic<T, typename T::base_type>::build(item, exec, linop,
manager);
} else {
std::cout << exec.get() << std::endl;
return manager->build_item<T>(item, exec, linop);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,6 @@ inline std::shared_ptr<T> get_pointer(const nlohmann::json& item,
assert(false);
}
} else {
// assert(false);
// TODO: manager
if (item.is_string()) {
std::cout << "search item" << std::endl;
std::string opt = item.get<std::string>();
Expand Down Expand Up @@ -279,8 +277,6 @@ inline std::shared_ptr<const LinOp> get_pointer<const LinOp>(
assert(false);
}
} else {
// TODO: manager
// assert(false);
if (item.is_string()) {
std::string opt = item.get<std::string>();
if (opt == std::string("given")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ namespace gko {
namespace extensions {
namespace file_config {


using ExecutorMap = std::unordered_map<std::string, std::shared_ptr<Executor>>;
using LinOpMap = std::unordered_map<std::string, std::shared_ptr<LinOp>>;
using LinOpFactoryMap =
Expand Down Expand Up @@ -328,7 +329,6 @@ inline void ResourceManager::build_item(const nlohmann::json& item)
std::string name = item.at("name").get<std::string>();
std::string base = get_base_class(item["base"].get<std::string>());

// if (base == std::string{})
{
auto ptr =
create_from_config<Executor>(item, base, nullptr, nullptr, this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ namespace extensions {
namespace file_config {


// return the input string without space
inline std::string remove_space(const std::string& str)
{
std::string nospace = str;
Expand All @@ -54,13 +55,16 @@ inline std::string remove_space(const std::string& str)
return nospace;
}


// get the base class of input
inline std::string get_base_class(const std::string& str)
{
auto langle_pos = str.find("<");
return remove_space(str.substr(0, langle_pos));
}


// get the template string in the first pair of <>
inline std::string get_base_template(const std::string& str)
{
auto langle_pos = str.find("<");
Expand All @@ -74,6 +78,7 @@ inline std::string get_base_template(const std::string& str)
}


// find the position of seperator `,` of input string
inline std::size_t find_template_sep(const std::string& str,
std::size_t pos = 0)
{
Expand All @@ -90,6 +95,8 @@ inline std::size_t find_template_sep(const std::string& str,
}


// Base on the base_template and type_template to decide the final template
// string. The base_template has higher priority than type_template.
inline std::string combine_template(const std::string& base_template,
const std::string& type_template)
{
Expand Down
20 changes: 11 additions & 9 deletions extensions/file_config/include/file_config/base/type_default.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,25 +50,30 @@ namespace file_config {
// denote the type we have the default supported list
enum class handle_type { ValueType, IndexType };


// tt_list_g::type will give the tt_list<supported types>
template <handle_type enum_item>
struct tt_list_g;

#define TT_LIST_G_PARTIAL(_enum, ...) \
template <> \
struct tt_list_g<handle_type::_enum> { \
using type = tt_list<__VA_ARGS__>; \
}

// tt_list_g::type will give the tt_list<supported types>
template <handle_type enum_item>
struct tt_list_g;

TT_LIST_G_PARTIAL(ValueType, double, float, std::complex<float>,
std::complex<double>);
TT_LIST_G_PARTIAL(ValueType, double, float, std::complex<double>,
std::complex<float>);
TT_LIST_G_PARTIAL(IndexType, int32, int64);


template <handle_type T>
using tt_list_g_t = typename tt_list_g<T>::type;


// return the default template
template <handle_type enum_item>
inline std::string get_default_string();

#define GET_DEFAULT_STRING_PARTIAL(_enum, _type) \
template <> \
inline std::string get_default_string<handle_type::_enum>() \
Expand All @@ -78,9 +83,6 @@ using tt_list_g_t = typename tt_list_g<T>::type;
static_assert(true, \
"This assert is used to counter the false positive extra " \
"semi-colon warnings")
// return the default template
template <handle_type enum_item>
inline std::string get_default_string();

GET_DEFAULT_STRING_PARTIAL(ValueType, double);
GET_DEFAULT_STRING_PARTIAL(IndexType, int);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ inline std::string get_string(type_list<K>)
*/
template <typename K, typename... Rest>
inline typename std::enable_if<(sizeof...(Rest) > 0), std::string>::type
get_string(type_list<K, Rest...>)
get_string(type_list<K, Rest...>)
{
return get_string<K>() + "," + get_string(type_list<Rest...>());
}
Expand Down
9 changes: 5 additions & 4 deletions extensions/file_config/include/file_config/base/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,11 @@ using CriterionFactory = ::gko::stop::CriterionFactory;
using Logger = ::gko::log::Logger;


#define ENUM_EXECUTER(_expand, _sep) \
_expand(CudaExecutor) _sep _expand(DpcppExecutor) \
_sep _expand(HipExecutor) _sep _expand(OmpExecutor) \
_sep _expand(ReferenceExecutor) ENUM_EXECUTER_USER(_expand, _sep)
#define ENUM_EXECUTER(_expand, _sep) \
_expand(CudaExecutor) _sep _expand(DpcppExecutor) \
_sep _expand(HipExecutor) \
_sep _expand(OmpExecutor) \
_sep _expand(ReferenceExecutor) ENUM_EXECUTER_USER(_expand, _sep)

#define ENUM_LINOP(_expand, _sep) _expand(Csr) _sep _expand(Isai)

Expand Down
2 changes: 0 additions & 2 deletions extensions/file_config/log/convergence.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ struct Generic<gko::log::Convergence<ValueType>> {
std::shared_ptr<const LinOp> linop,
ResourceManager* manager)
{
// auto exec_ptr =
// get_pointer_check<Executor>(item, "exec", exec, linop, manager);
auto mask_value = get_mask_value_with_default(
item, "enabled_events", gko::log::Logger::all_events_mask);
auto ptr = gko::log::Convergence<ValueType>::create(mask_value);
Expand Down
2 changes: 1 addition & 1 deletion extensions/file_config/matrix/csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ struct Generic<gko::matrix::Csr<ValueType, IndexType>> {
ptr->read(data);
}

// add_logger(ptr, item, exec, linop, manager);
add_logger(ptr, item, exec, linop, manager);
return std::move(ptr);
}
};
Expand Down
8 changes: 4 additions & 4 deletions extensions/file_config/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ function(gkoext_file_config_create_test test_name)
)
set_target_properties(${TEST_TARGET_NAME} PROPERTIES
OUTPUT_NAME ${test_name})
# if (GINKGO_CHECK_CIRCULAR_DEPS)
# target_link_libraries(${TEST_TARGET_NAME} PRIVATE "${GINKGO_CIRCULAR_DEPS_FLAGS}")
# endif()
if (GINKGO_CHECK_CIRCULAR_DEPS)
target_link_libraries(${TEST_TARGET_NAME} PRIVATE "${GINKGO_CIRCULAR_DEPS_FLAGS}")
endif()
target_link_libraries(${TEST_TARGET_NAME} PRIVATE Ginkgo::ginkgo file_config GTest::Main GTest::GTest nlohmann_json::nlohmann_json ${ARGN})
add_test(NAME ${REL_BINARY_DIR}/${test_name}
COMMAND ${TEST_TARGET_NAME}
Expand All @@ -27,4 +27,4 @@ add_subdirectory(executor)
add_subdirectory(log)
add_subdirectory(matrix)
add_subdirectory(preconditioner)
add_subdirectory(stop)
add_subdirectory(stop)
1 change: 1 addition & 0 deletions extensions/file_config/test/base/type_resolving.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ struct DummyType {
};
};


template <typename K, typename T>
struct DummyType2 {
using ktype = K;
Expand Down
2 changes: 2 additions & 0 deletions extensions/file_config/test/base/type_string.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,14 @@ TEST(GetString, GetStringFromTypeList)
ASSERT_EQ(get_string(type_list<int, double>{}), "int,double");
}


TEST(GetString, GetStringFromComplex)
{
ASSERT_EQ(get_string<std::complex<float>>(), "complex<float>");
ASSERT_EQ(get_string<std::complex<double>>(), "complex<double>");
}


TEST(GetString, GetStringFromBase)
{
ASSERT_EQ(get_string<gko::solver::LowerTrs<>>(), "LowerTrs<double,int>");
Expand Down
21 changes: 16 additions & 5 deletions extensions/file_config/test/custom/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/base/executor.hpp>


#include <exception>


#include <gtest/gtest.h>
#include <nlohmann/json.hpp>

Expand All @@ -52,10 +55,18 @@ namespace extensions {
namespace file_config {


IMPLEMENT_EMPTY_BRIDGE(RM_Executor, TestExecutor);
template <>
inline std::shared_ptr<Executor>
create_from_config<RM_Executor, RM_Executor::TestExecutor, Executor>(
const nlohmann::json& item, std::shared_ptr<const Executor> exec,
std::shared_ptr<const LinOp> linop, ResourceManager* manager)
{
throw std::runtime_error("TestExecutor");
return nullptr;
}


}
} // namespace file_config
} // namespace extensions
} // namespace gko

Expand All @@ -70,8 +81,8 @@ TEST(ReferenceExecutor, CreateCorrectCustomExecutor)
{"base": "TestExecutor"}
)");

auto ptr =
gko::extensions::file_config::create_from_config<gko::Executor>(data);

ASSERT_EQ(ptr.get(), nullptr);
ASSERT_THROW(
gko::extensions::file_config::create_from_config<gko::Executor>(data),
std::runtime_error);
}
Loading

0 comments on commit 8e8e437

Please sign in to comment.