diff --git a/src/tensor.h b/src/tensor.h index 271d4c5..ea964fe 100644 --- a/src/tensor.h +++ b/src/tensor.h @@ -494,6 +494,60 @@ struct Tensor final { throw std::runtime_error("softmax: unsupported shape size"); } + Tensor argmax(size_t dimension) const { + if (dimension >= shape.size()) { + throw std::runtime_error(fmt::format( + "argmax: invalid dimension {} for shape {}", dimension, shape)); + } + std::vector newShape(shape); + newShape.erase(newShape.begin() + dimension); + if (newShape.size() == 0) { + newShape.push_back(1); + } + Tensor result(newShape); + + if (shape.size() == 1) { + T maxValue = data[offset]; + size_t maxIndex = 0; + for (size_t dim0 = 1; dim0 < shape[0]; dim0++) { + T currentValue = data[offset + dim0 * strides[0]]; + if (currentValue > maxValue) { + maxValue = currentValue; + maxIndex = dim0; + } + } + result.data[0] = maxIndex; + return result; + } + + if (shape.size() == 2) { + // Index into strides and shape, based on which dimension (row or column) + // we calculate on. + int i0 = 1; + int i1 = 0; + if (dimension == 1) { + i0 = 0; + i1 = 1; + } + for (size_t dim0 = 0; dim0 < shape[i0]; dim0++) { + T maxValue = std::numeric_limits::min(); + size_t maxIndex = 0; + for (size_t dim1 = 0; dim1 < shape[i1]; dim1++) { + T currentValue = + data[offset + dim0 * strides[i0] + dim1 * strides[i1]]; + if (currentValue > maxValue) { + maxValue = currentValue; + maxIndex = dim1; + } + } + result.data[dim0] = maxIndex; + } + return result; + } + + throw std::runtime_error("argmax: unsupported shape size"); + } + Tensor sum(size_t dimension = 0) const { if (dimension >= shape.size()) { throw std::runtime_error(fmt::format( diff --git a/unittests/CMakeLists.txt b/unittests/CMakeLists.txt index 5f9881c..a3ed1ee 100644 --- a/unittests/CMakeLists.txt +++ b/unittests/CMakeLists.txt @@ -10,6 +10,7 @@ add_executable(cppdl_test ${CMAKE_CURRENT_SOURCE_DIR}/tensor_reshape.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tensor_softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tensor_sum.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/tensor_argmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/nn_activation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/nn_linear.cpp ) diff --git a/unittests/tensor_argmax.cpp b/unittests/tensor_argmax.cpp new file mode 100644 index 0000000..6b5f151 --- /dev/null +++ b/unittests/tensor_argmax.cpp @@ -0,0 +1,28 @@ +#include + +#include "tensor.h" + +TEST(TensorArgmax, Argmax1D) { + Tensor t = Tensor::vector({1.0f, 3.0f, 2.0f, 4.0f}); + auto argmax = t.argmax(0); + EXPECT_EQ(argmax.item(), 3); +} + +TEST(TensorArgmax, Argmax2D_Dim0) { + Tensor t = + Tensor::matrix2d({{3.0f, 2.0f}, {1.0f, 4.0f}, {5.0f, 1.0f}}); + auto argmax = t.argmax(0); + EXPECT_EQ(argmax.shape, std::vector({2})); + EXPECT_EQ(argmax[0].item(), 2); + EXPECT_EQ(argmax[1].item(), 1); +} + +TEST(TensorArgmax, Argmax2D_Dim1) { + Tensor t = + Tensor::matrix2d({{1.0f, 4.0f}, {3.0f, 2.0f}, {5.0f, 6.0f}}); + auto argmax = t.argmax(1); + EXPECT_EQ(argmax.shape, std::vector({3})); + EXPECT_EQ(argmax[0].item(), 1); + EXPECT_EQ(argmax[1].item(), 0); + EXPECT_EQ(argmax[2].item(), 1); +}