diff --git a/include/llama/Iterator.hpp b/include/llama/Iterator.hpp new file mode 100644 index 0000000000..3488b0eb75 --- /dev/null +++ b/include/llama/Iterator.hpp @@ -0,0 +1,102 @@ +#include "View.hpp" + +#include +//#include + +namespace llama +{ + // requires Boost 1.74 which is quite new +#if 0 + template + struct Iterator + : boost::stl_interfaces::proxy_iterator_interface< + Iterator, + std::random_access_iterator_tag, + decltype(std::declval()(ArrayDomain<1>{}))> + { + constexpr decltype(auto) operator*() const + { + return (*view)(coord); + } + + constexpr auto operator+=(std::ptrdiff_t n) -> Iterator& + { + coord[0] += n; + return *this; + } + + friend constexpr auto operator-(const Iterator& a, const Iterator& b) + { + return a.coord[0] - b.coord[0]; + } + + friend constexpr bool operator==(const Iterator& a, const Iterator& b) + { + return a.coord == b.coord; + } + + ArrayDomain<1> coord; + View* view; + }; +#endif + + template + struct Iterator + : boost::iterators::iterator_facade< + Iterator, + typename View::VirtualDatumType, + std::random_access_iterator_tag, + typename View::VirtualDatumType, + std::ptrdiff_t> + { + constexpr decltype(auto) dereference() const + { + return (*view)(coord); + } + + constexpr bool equal(const Iterator& other) const + { + return coord == other.coord; + } + + constexpr auto increment() -> Iterator& + { + coord[0]++; + return *this; + } + + constexpr auto decrement() -> Iterator& + { + coord[0]--; + return *this; + } + + constexpr auto advance(std::ptrdiff_t n) -> Iterator& + { + coord[0] += n; + return *this; + } + + constexpr auto distance_to(const Iterator& other) const -> std::ptrdiff_t + { + return static_cast(other.coord[0]) - static_cast(coord[0]); + } + + ArrayDomain<1> coord; + View* view; + }; + + template + auto begin(View& view) -> Iterator + { + static_assert(View::ArrayDomain::rank == 1, "Iterators for non-1D views are not implemented"); + return {{}, ArrayDomain<1>{}, &view}; + } + + template + auto end(View& view) -> Iterator + { + static_assert(View::ArrayDomain::rank == 1, "Iterators for non-1D views are not implemented"); + return {{}, view.mapping.arrayDomainSize, &view}; + } +} // namespace llama diff --git a/include/llama/llama.hpp b/include/llama/llama.hpp index 82116e1f51..7a47412dac 100644 --- a/include/llama/llama.hpp +++ b/include/llama/llama.hpp @@ -34,6 +34,7 @@ #include "Allocators.hpp" #include "ArrayDomainRange.hpp" #include "Core.hpp" +#include "Iterator.hpp" #include "View.hpp" #include "macros.hpp" #include "mapping/AoS.hpp" diff --git a/tests/iterator.cpp b/tests/iterator.cpp new file mode 100644 index 0000000000..7604a29ba9 --- /dev/null +++ b/tests/iterator.cpp @@ -0,0 +1,60 @@ +#include +#include +#include +#include + +// clang-format off +namespace tag { + struct X {}; + struct Y {}; + struct Z {}; +} + +using Position = llama::DS< + llama::DE, + llama::DE, + llama::DE +>; +// clang-format on + +TEST_CASE("iterator") +{ + using ArrayDomain = llama::ArrayDomain<1>; + constexpr auto arrayDomain = ArrayDomain{32}; + constexpr auto mapping = llama::mapping::AoS{arrayDomain}; + auto view = llama::allocView(mapping); + + for (auto vd : view) + { + vd(tag::X{}) = 1; + vd(tag::Y{}) = 2; + vd(tag::Z{}) = 3; + } + std::transform(begin(view), end(view), begin(view), [](auto vd) { return vd * 2; }); + const int sumY = std::accumulate(begin(view), end(view), 0, [](int acc, auto vd) { return acc + vd(tag::Y{}); }); + CHECK(sumY == 128); +} + +TEST_CASE("iterator.std_copy") +{ + using ArrayDomain = llama::ArrayDomain<1>; + constexpr auto arrayDomain = ArrayDomain{32}; + auto aosView = llama::allocView(llama::mapping::AoS{arrayDomain}); + auto soaView = llama::allocView(llama::mapping::SoA{arrayDomain}); + + int i = 0; + for (auto vd : aosView) + { + vd(tag::X{}) = ++i; + vd(tag::Y{}) = ++i; + vd(tag::Z{}) = ++i; + } + std::copy(begin(aosView), end(aosView), begin(soaView)); + i = 0; + for (auto vd : soaView) + { + CHECK(vd(tag::X{}) == ++i); + CHECK(vd(tag::Y{}) == ++i); + CHECK(vd(tag::Z{}) == ++i); + } +}