Skip to content

Commit

Permalink
add dot product for Array
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed Feb 28, 2022
1 parent 9ec85ac commit 53eba75
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
10 changes: 10 additions & 0 deletions include/llama/Array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,16 @@ namespace llama
prod *= s;
return prod;
}

template<typename T, std::size_t N>
LLAMA_FN_HOST_ACC_INLINE constexpr auto dot(Array<T, N> a, Array<T, N> b) -> T
{
T r = 0;
if constexpr(N > 0)
for(std::size_t i = 0; i < N; i++)
r += a[i] * b[i];
return r;
}
} // namespace llama

namespace std
Expand Down
8 changes: 8 additions & 0 deletions tests/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,11 @@ TEST_CASE("Array.product")
STATIC_REQUIRE(product(llama::Array{1, 2}) == 2);
STATIC_REQUIRE(product(llama::Array{3, 2, 1}) == 6);
}

TEST_CASE("Array.dot")
{
STATIC_REQUIRE(llama::dot(llama::Array<int, 0>{}, llama::Array<int, 0>{}) == 0);
STATIC_REQUIRE(llama::dot(llama::Array{2}, llama::Array{3}) == 6);
STATIC_REQUIRE(llama::dot(llama::Array{4, 5}, llama::Array{6, 7}) == 59);
STATIC_REQUIRE(llama::dot(llama::Array{1, 2, 3, 4}, llama::Array{-5, 6, -7, 0}) == -14);
}

0 comments on commit 53eba75

Please sign in to comment.