Skip to content

Commit

Permalink
tensor: argmax
Browse files Browse the repository at this point in the history
  • Loading branch information
fotcorn committed Feb 29, 2024
1 parent 4315a9b commit 7c115e2
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 0 deletions.
54 changes: 54 additions & 0 deletions src/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,60 @@ struct Tensor final {
throw std::runtime_error("softmax: unsupported shape size");
}

Tensor<size_t> argmax(size_t dimension) const {
if (dimension >= shape.size()) {
throw std::runtime_error(fmt::format(
"argmax: invalid dimension {} for shape {}", dimension, shape));
}
std::vector newShape(shape);
newShape.erase(newShape.begin() + dimension);
if (newShape.size() == 0) {
newShape.push_back(1);
}
Tensor<size_t> result(newShape);

if (shape.size() == 1) {
T maxValue = data[offset];
size_t maxIndex = 0;
for (size_t dim0 = 1; dim0 < shape[0]; dim0++) {
T currentValue = data[offset + dim0 * strides[0]];
if (currentValue > maxValue) {
maxValue = currentValue;
maxIndex = dim0;
}
}
result.data[0] = maxIndex;
return result;
}

if (shape.size() == 2) {
// Index into strides and shape, based on which dimension (row or column)
// we calculate on.
int i0 = 1;
int i1 = 0;
if (dimension == 1) {
i0 = 0;
i1 = 1;
}
for (size_t dim0 = 0; dim0 < shape[i0]; dim0++) {
T maxValue = std::numeric_limits<T>::min();
size_t maxIndex = 0;
for (size_t dim1 = 0; dim1 < shape[i1]; dim1++) {
T currentValue =
data[offset + dim0 * strides[i0] + dim1 * strides[i1]];
if (currentValue > maxValue) {
maxValue = currentValue;
maxIndex = dim1;
}
}
result.data[dim0] = maxIndex;
}
return result;
}

throw std::runtime_error("argmax: unsupported shape size");
}

Tensor<T> sum(size_t dimension = 0) const {
if (dimension >= shape.size()) {
throw std::runtime_error(fmt::format(
Expand Down
1 change: 1 addition & 0 deletions unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ add_executable(cppdl_test
${CMAKE_CURRENT_SOURCE_DIR}/tensor_reshape.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor_softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor_sum.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor_argmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nn_activation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nn_linear.cpp
)
Expand Down
28 changes: 28 additions & 0 deletions unittests/tensor_argmax.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#include <gtest/gtest.h>

#include "tensor.h"

TEST(TensorArgmax, Argmax1D) {
Tensor<float> t = Tensor<float>::vector({1.0f, 3.0f, 2.0f, 4.0f});
auto argmax = t.argmax(0);
EXPECT_EQ(argmax.item(), 3);
}

TEST(TensorArgmax, Argmax2D_Dim0) {
Tensor<float> t =
Tensor<float>::matrix2d({{3.0f, 2.0f}, {1.0f, 4.0f}, {5.0f, 1.0f}});
auto argmax = t.argmax(0);
EXPECT_EQ(argmax.shape, std::vector<size_t>({2}));
EXPECT_EQ(argmax[0].item(), 2);
EXPECT_EQ(argmax[1].item(), 1);
}

TEST(TensorArgmax, Argmax2D_Dim1) {
Tensor<float> t =
Tensor<float>::matrix2d({{1.0f, 4.0f}, {3.0f, 2.0f}, {5.0f, 6.0f}});
auto argmax = t.argmax(1);
EXPECT_EQ(argmax.shape, std::vector<size_t>({3}));
EXPECT_EQ(argmax[0].item(), 1);
EXPECT_EQ(argmax[1].item(), 0);
EXPECT_EQ(argmax[2].item(), 1);
}

0 comments on commit 7c115e2

Please sign in to comment.