Skip to content

Commit

Permalink
Add ScopedUpdate
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed Jul 1, 2022
1 parent 9dfe9b8 commit c6ca964
Show file tree
Hide file tree
Showing 2 changed files with 237 additions and 0 deletions.
131 changes: 131 additions & 0 deletions include/llama/VirtualRecord.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

#pragma once

#include "Concepts.hpp"
#include "HasRanges.hpp"
#include "ProxyRefOpMixin.hpp"
#include "View.hpp"

#include <iosfwd>
Expand Down Expand Up @@ -783,6 +785,135 @@ namespace llama
[functor = std::forward<Functor>(functor), &vr = vr](auto rc)
LLAMA_LAMBDA_INLINE_WITH_SPECIFIERS(constexpr mutable) { std::forward<Functor>(functor)(vr(rc)); });
}

namespace internal
{
// gets the value type for a given T, where T models a reference type. T is either an l-value reference, a
// proxy reference or a VirtualRecord
template<typename T, typename = void>
struct ValueOf
{
static_assert(sizeof(T) == 0, "T does not model a reference");
};

template<typename T>
struct ValueOf<T, std::enable_if_t<is_VirtualRecord<T>>>
{
using type = One<typename T::AccessibleRecordDim>;
};

#ifdef __cpp_lib_concepts
template<ProxyReference T>
#else
template<typename T>
#endif
struct ValueOf<T, std::enable_if_t<isProxyReference<T>>>
{
using type = typename T::value_type;
};

template<typename T>
struct ValueOf<T&>
{
using type = T;
};
} // namespace internal

template<typename Reference, typename = void>
struct ScopedUpdate : internal::ValueOf<Reference>::type
{
using value_type = typename internal::ValueOf<Reference>::type;

Reference ref;

LLAMA_FN_HOST_ACC_INLINE ScopedUpdate(Reference r) : value_type(r), ref(r)
{
}

ScopedUpdate(const ScopedUpdate&) = delete;
auto operator=(const ScopedUpdate&) -> ScopedUpdate& = delete;

ScopedUpdate(ScopedUpdate&&) noexcept = default;
auto operator=(ScopedUpdate&&) noexcept -> ScopedUpdate& = default;

using value_type::operator=;

LLAMA_FN_HOST_ACC_INLINE ~ScopedUpdate()
{
ref = static_cast<value_type&>(*this);
}
};

template<typename Reference>
struct ScopedUpdate<
Reference,
std::enable_if_t<std::is_fundamental_v<typename internal::ValueOf<Reference>::type>>>
: ProxyRefOpMixin<ScopedUpdate<Reference>, typename internal::ValueOf<Reference>::type>
{
using value_type = typename internal::ValueOf<Reference>::type;

value_type value;
Reference ref;

LLAMA_FN_HOST_ACC_INLINE ScopedUpdate(Reference r) : value(r), ref(r)
{
}

ScopedUpdate(const ScopedUpdate&) = delete;
auto operator=(const ScopedUpdate&) -> ScopedUpdate& = delete;

ScopedUpdate(ScopedUpdate&&) noexcept = default;
auto operator=(ScopedUpdate&&) noexcept -> ScopedUpdate& = default;

LLAMA_FN_HOST_ACC_INLINE operator const value_type&() const
{
return value;
}

LLAMA_FN_HOST_ACC_INLINE operator value_type&()
{
return value;
}

LLAMA_FN_HOST_ACC_INLINE auto operator=(value_type v) -> ScopedUpdate&
{
value = v;
return *this;
}

LLAMA_FN_HOST_ACC_INLINE ~ScopedUpdate()
{
ref = value;
}
};

namespace internal
{
template<typename T, typename = void>
struct ReferenceTo
{
using type = T&;
};

template<typename T>
struct ReferenceTo<T, std::enable_if_t<is_VirtualRecord<T> && !is_One<T>>>
{
using type = T;
};

#ifdef __cpp_lib_concepts
template<ProxyReference T>
#else
template<typename T>
#endif
struct ReferenceTo<T, std::enable_if_t<isProxyReference<T>>>
{
using type = T;
};
} // namespace internal

template<typename T>
ScopedUpdate(T) -> ScopedUpdate<typename internal::ReferenceTo<std::remove_reference_t<T>>::type>;
} // namespace llama

template<typename View, typename BoundRecordCoord, bool OwnView>
Expand Down
106 changes: 106 additions & 0 deletions tests/virtualrecord.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1083,3 +1083,109 @@ TEST_CASE("VirtualRecord.reference_to_One")
CHECK(v(tag::Y{}) == 22);
CHECK(v(tag::Z{}) == 3);
}

TEST_CASE("ValueOf")
{
STATIC_REQUIRE(std::is_same_v<llama::internal::ValueOf<int&>::type, int>);

using One = llama::One<Vec3I>;
STATIC_REQUIRE(std::is_same_v<llama::internal::ValueOf<decltype(One{}())>::type, One>);

auto mapping = llama::mapping::BitPackedIntSoA<llama::ArrayExtents<int, 4>, Vec3I>{{}, 17};
auto v = llama::allocView(mapping);
[[maybe_unused]] auto ref = v(1)(tag::X{});
#ifdef __cpp_lib_concepts
STATIC_REQUIRE(llama::ProxyReference<decltype(ref)>);
#endif
STATIC_REQUIRE(std::is_same_v<llama::internal::ValueOf<decltype(ref)>::type, int>);
}
TEST_CASE("ScopedUpdate.Fundamental")
{
int i = 1;
{
llama::ScopedUpdate u(i);
STATIC_REQUIRE(std::is_same_v<decltype(u), llama::ScopedUpdate<int&>>);
u = 23;
CHECK(u == 23);
CHECK(i == 1);
u = 24;
CHECK(u == 24);
CHECK(i == 1);
}
CHECK(i == 24);
}

TEST_CASE("ScopedUpdate.Object")
{
std::vector v = {1};
{
llama::ScopedUpdate u(v);
STATIC_REQUIRE(std::is_same_v<decltype(u), llama::ScopedUpdate<std::vector<int>&>>);
u.push_back(2);
CHECK(u == std::vector{1, 2});
CHECK(v == std::vector{1});
u = std::vector{3, 4, 5};
CHECK(u == std::vector{3, 4, 5});
CHECK(v == std::vector{1});
}
CHECK(v == std::vector{3, 4, 5});
}

TEST_CASE("ScopedUpdate.ProxyRef")
{
auto mapping = llama::mapping::BitPackedIntSoA<llama::ArrayExtents<int, 4>, Vec3I>{{}, 17};
auto v = llama::allocView(mapping);
auto i = v(1)(tag::X{});
i = 1;
{
llama::ScopedUpdate u(i);
STATIC_REQUIRE(std::is_same_v<decltype(u), llama::ScopedUpdate<decltype(i)>>);
u = 23;
CHECK(u == 23);
CHECK(i == 1);
u = 24;
CHECK(u == 24);
CHECK(i == 1);
}
CHECK(i == 24);
}

TEST_CASE("ScopedUpdate.VirtualRecord")
{
auto test = [](auto&& v)
{
llama::forEachLeaf(v, [i = 0](auto& field) mutable { field = ++i; });
{
llama::ScopedUpdate u(v);
if constexpr(llama::is_One<std::remove_reference_t<decltype(v)>>)
{
STATIC_REQUIRE(
std::is_same_v<decltype(u), llama::ScopedUpdate<std::remove_reference_t<decltype(v)>&>>);
}
else
{
STATIC_REQUIRE(std::is_same_v<decltype(u), llama::ScopedUpdate<decltype(v())>>);
}
u(tag::X{}) = 11;
CHECK(u(tag::X{}) == 11);
CHECK(u(tag::Y{}) == 2);
CHECK(u(tag::Z{}) == 3);
CHECK(v(tag::X{}) == 1);
CHECK(v(tag::Y{}) == 2);
CHECK(v(tag::Z{}) == 3);
u = 24;
CHECK(u(tag::X{}) == 24);
CHECK(u(tag::Y{}) == 24);
CHECK(u(tag::Z{}) == 24);
CHECK(v(tag::X{}) == 1);
CHECK(v(tag::Y{}) == 2);
CHECK(v(tag::Z{}) == 3);
}
CHECK(v(tag::X{}) == 24);
CHECK(v(tag::Y{}) == 24);
CHECK(v(tag::Z{}) == 24);
};
llama::One<Vec3I> v;
test(v);
test(v());
}

0 comments on commit c6ca964

Please sign in to comment.