diff --git a/include/gridtools/fn/sid_neighbor_table.hpp b/include/gridtools/fn/sid_neighbor_table.hpp index dd936dc65..a630fdc15 100644 --- a/include/gridtools/fn/sid_neighbor_table.hpp +++ b/include/gridtools/fn/sid_neighbor_table.hpp @@ -15,6 +15,7 @@ #include "../common/array.hpp" #include "../common/const_ptr_deref.hpp" #include "../fn/unstructured.hpp" +#include "../sid/as_const.hpp" #include "../sid/concept.hpp" namespace gridtools::fn::sid_neighbor_table { @@ -61,14 +62,15 @@ namespace gridtools::fn::sid_neighbor_table { static_assert(!std::is_same_v, "The index dimension and the neighbor dimension must be different."); - const auto origin = sid::get_origin(sid); - const auto strides = sid::get_strides(sid); + decltype(auto) const_sid = sid::as_const(std::forward(sid)); + const auto origin = sid::get_origin(const_sid); + const auto strides = sid::get_strides(const_sid); return sid_neighbor_table, - sid::strides_type>{ + decltype(origin), + decltype(strides)>{ origin, strides}; // Note: putting the return type into the function signature will crash nvcc 12.0 } } // namespace sid_neighbor_table_impl_ diff --git a/include/gridtools/sid/as_const.hpp b/include/gridtools/sid/as_const.hpp index 0e8f1ba8b..88d90cf75 100644 --- a/include/gridtools/sid/as_const.hpp +++ b/include/gridtools/sid/as_const.hpp @@ -51,14 +51,14 @@ namespace gridtools { * probably might we need the `host` and `device` variations as well */ template >, + class Ptr = sid::ptr_type>>, std::enable_if_t && !std::is_const_v>, int> = 0> as_const_impl_::const_adapter as_const(Src &&src) { return {std::forward(src)}; } template >, + class Ptr = sid::ptr_type>>, std::enable_if_t || std::is_const_v>, int> = 0> decltype(auto) as_const(Src &&src) { return std::forward(src); diff --git a/tests/unit_tests/fn/test_fn_sid_neighbor_table.cu b/tests/unit_tests/fn/test_fn_sid_neighbor_table.cu index dcdad7aed..cac45338e 100644 --- a/tests/unit_tests/fn/test_fn_sid_neighbor_table.cu +++ b/tests/unit_tests/fn/test_fn_sid_neighbor_table.cu @@ -42,7 +42,10 @@ namespace gridtools::fn { using dim_hymap_t = hymap::keys; auto contents = sid::synthetic() .set(sid::host_device::simple_ptr_holder(device_data.get())) - .set(dim_hymap_t::make_values(num_neighbors, 1)); + .set(dim_hymap_t::make_values(num_neighbors, 1)) + // for whatever reason, setting strides_kind is required + // by Clang-CUDA (tested Clang 17 + CUDA 12.4) + .set(); const auto table = as_neighbor_table(contents); using table_t = std::decay_t; diff --git a/tests/unit_tests/sid/test_sid_as_const.cpp b/tests/unit_tests/sid/test_sid_as_const.cpp index 7f32ebf98..e6af42e0c 100644 --- a/tests/unit_tests/sid/test_sid_as_const.cpp +++ b/tests/unit_tests/sid/test_sid_as_const.cpp @@ -32,5 +32,15 @@ namespace gridtools { static_assert(std::is_same_v, double const *>); EXPECT_EQ(sid::get_origin(src)(), sid::get_origin(testee)()); } + + TEST(as_const, c_array) { + int src[3][2] = {{0, 1}, {10, 11}, {20, 21}}; + auto testee = sid::as_const(src); + using testee_t = decltype(testee); + + static_assert(is_sid()); + static_assert(std::is_same_v, int const *>); + EXPECT_EQ(sid::get_origin(src)(), sid::get_origin(testee)()); + } } // namespace } // namespace gridtools