From 53eba7524f2d47b3a54c3aa947f988314288c06a Mon Sep 17 00:00:00 2001 From: Bernhard Manfred Gruber Date: Fri, 25 Feb 2022 18:06:35 +0100 Subject: [PATCH] add dot product for Array --- include/llama/Array.hpp | 10 ++++++++++ tests/array.cpp | 8 ++++++++ 2 files changed, 18 insertions(+) diff --git a/include/llama/Array.hpp b/include/llama/Array.hpp index af12abc052..0c76ab2326 100644 --- a/include/llama/Array.hpp +++ b/include/llama/Array.hpp @@ -208,6 +208,16 @@ namespace llama prod *= s; return prod; } + + template + LLAMA_FN_HOST_ACC_INLINE constexpr auto dot(Array a, Array b) -> T + { + T r = 0; + if constexpr(N > 0) + for(std::size_t i = 0; i < N; i++) + r += a[i] * b[i]; + return r; + } } // namespace llama namespace std diff --git a/tests/array.cpp b/tests/array.cpp index 05bfab83fc..7f4335351d 100644 --- a/tests/array.cpp +++ b/tests/array.cpp @@ -52,3 +52,11 @@ TEST_CASE("Array.product") STATIC_REQUIRE(product(llama::Array{1, 2}) == 2); STATIC_REQUIRE(product(llama::Array{3, 2, 1}) == 6); } + +TEST_CASE("Array.dot") +{ + STATIC_REQUIRE(llama::dot(llama::Array{}, llama::Array{}) == 0); + STATIC_REQUIRE(llama::dot(llama::Array{2}, llama::Array{3}) == 6); + STATIC_REQUIRE(llama::dot(llama::Array{4, 5}, llama::Array{6, 7}) == 59); + STATIC_REQUIRE(llama::dot(llama::Array{1, 2, 3, 4}, llama::Array{-5, 6, -7, 0}) == -14); +}