diff --git a/src/nn.cpp b/src/nn.cpp
index 5e4515c..a7e52db 100644
--- a/src/nn.cpp
+++ b/src/nn.cpp
@@ -34,8 +34,7 @@ int main() {
   auto flatResult = result.reshape({100});
   auto loss = (flatResult - datasetLabels).reshape({100, 1});
 
-  auto lossSum = Tensor::ones({1, loss.shape[0]}).matmul(loss);
-  fmt::println("Loss: {}", lossSum[0].item());
+  fmt::println("Loss: {}", loss.sum().item());
 
   // Backwards pass.
   layer0.zeroGrad();
diff --git a/src/nn.h b/src/nn.h
index 3b20e86..475c212 100644
--- a/src/nn.h
+++ b/src/nn.h
@@ -28,8 +28,7 @@ class LinearLayer {
 
   Tensor backward(const Tensor &input, const Tensor &outGrad) {
-    auto outGradSum =
-        Tensor::ones({1, outGrad.shape[0]}).matmul(outGrad);
+    auto outGradSum = outGrad.sum();
     biasGrad = biasGrad + outGradSum;
     weightGrad = weightGrad + outGrad.transpose().matmul(input);
     return outGrad.matmul(weight);
   }
diff --git a/src/tensor.h b/src/tensor.h
index 6f14e7f..03ab7e3 100644
--- a/src/tensor.h
+++ b/src/tensor.h
@@ -494,6 +494,49 @@ struct Tensor final {
     throw std::runtime_error("softmax: unsupported shape size");
   }
 
+  Tensor sum(size_t dimension = 0) const {
+    if (dimension >= shape.size()) {
+      throw std::runtime_error(fmt::format(
+          "sum: invalid dimension {} for shape {}", dimension, shape));
+    }
+    std::vector<size_t> newShape(shape);
+    newShape.erase(newShape.begin() + dimension);
+    if (newShape.size() == 0) {
+      newShape.push_back(1);
+    }
+    Tensor result(newShape);
+
+    if (shape.size() == 1) {
+      float sum = 0.0f;
+      for (size_t dim0 = 0; dim0 < shape[0]; dim0++) {
+        sum += data[offset + dim0 * strides[0]];
+      }
+      result.data.get()[0] = sum;
+      return result;
+    }
+
+    if (shape.size() == 2) {
+      // Index into strides and shape, based on which dimension (row or
+      // column) we calculate on.
+      int i0 = 1;
+      int i1 = 0;
+      if (dimension == 1) {
+        i0 = 0;
+        i1 = 1;
+      }
+      for (size_t dim0 = 0; dim0 < shape[i0]; dim0++) {
+        float sum = 0.0f;
+        for (size_t dim1 = 0; dim1 < shape[i1]; dim1++) {
+          sum += data[offset + dim0 * strides[i0] + dim1 * strides[i1]];
+        }
+        result.data.get()[dim0] = sum;
+      }
+      return result;
+    }
+
+    throw std::runtime_error("sum: unsupported shape size");
+  }
+
   friend std::ostream &operator<<(std::ostream &os, const Tensor &t) {
     os << t.toString();
     return os;
diff --git a/unittests/CMakeLists.txt b/unittests/CMakeLists.txt
index 01ef0ad..5f9881c 100644
--- a/unittests/CMakeLists.txt
+++ b/unittests/CMakeLists.txt
@@ -8,9 +8,10 @@ add_executable(cppdl_test
   ${CMAKE_CURRENT_SOURCE_DIR}/tensor_stack.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/tensor_transpose.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/tensor_reshape.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/tensor_softmax.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/tensor_sum.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/nn_activation.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/nn_linear.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/tensor_softmax.cpp
 )
 
 target_include_directories(cppdl_test PRIVATE ../src)
diff --git a/unittests/nn_linear.cpp b/unittests/nn_linear.cpp
index da6f02a..d05a276 100644
--- a/unittests/nn_linear.cpp
+++ b/unittests/nn_linear.cpp
@@ -22,7 +22,7 @@ TEST_F(LinearLayerTest, Backward) {
   auto activations = linearLayer.forward(input);
   outGrad = linearLayer.backward(input, outGrad);
 
-  ASSERT_EQ(linearLayer.biasGrad, Tensor::matrix2d({{2.0f, 4.5f}}));
+  ASSERT_EQ(linearLayer.biasGrad, Tensor::vector({2.0f, 4.5f}));
   ASSERT_EQ(
       linearLayer.weightGrad,
       Tensor::matrix2d({{6.5f, 8.5f, 10.5f}, {12.0f, 16.5f, 21.0f}}));
diff --git a/unittests/tensor_sum.cpp b/unittests/tensor_sum.cpp
new file mode 100644
index 0000000..2121595
--- /dev/null
+++ b/unittests/tensor_sum.cpp
@@ -0,0 +1,23 @@
+#include <gtest/gtest.h>
+
+#include "tensor.h"
+
+TEST(TensorSum, Sum1D) {
+  Tensor t = Tensor::vector({1.0f, 2.0f, 3.0f, 4.0f});
+  auto summed = t.sum();
+  EXPECT_NEAR(summed[0].item(), 10.0f, 1e-6);
+}
+
+TEST(TensorSum, Sum2D_Dim0) {
+  Tensor t = Tensor::matrix2d({{1.0f, 2.0f}, {3.0f, 4.0f}});
+  auto summed = t.sum(0);
+  EXPECT_NEAR(summed[0].item(), 4.0f, 1e-6);
+  EXPECT_NEAR(summed[1].item(), 6.0f, 1e-6);
+}
+
+TEST(TensorSum, Sum2D_Dim1) {
+  Tensor t = Tensor::matrix2d({{1.0f, 2.0f}, {3.0f, 4.0f}});
+  auto summed = t.sum(1);
+  EXPECT_NEAR(summed[0].item(), 3.0f, 1e-6);
+  EXPECT_NEAR(summed[1].item(), 7.0f, 1e-6);
+}
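
For context, a minimal usage sketch of the new Tensor::sum (not part of the patch itself); the main() scaffolding and the fmt include are assumptions, while the Tensor::vector/Tensor::matrix2d factories, operator[], and item() are the accessors the unit tests above rely on:

// Hypothetical usage sketch -- assumes the Tensor API exercised by the
// unit tests above (Tensor::vector, Tensor::matrix2d, operator[], item()).
#include <fmt/core.h>

#include "tensor.h"

int main() {
  // 1D: sum() collapses the only dimension into a shape-{1} tensor, so
  // item() reads the scalar directly (as nn.cpp now does for the loss).
  Tensor v = Tensor::vector({1.0f, 2.0f, 3.0f, 4.0f});
  fmt::println("total = {}", v.sum().item()); // 10

  // 2D: sum(0) reduces over rows (per-column sums); sum(1) reduces over
  // columns (per-row sums), matching the Sum2D_Dim0/Sum2D_Dim1 tests.
  Tensor m = Tensor::matrix2d({{1.0f, 2.0f}, {3.0f, 4.0f}});
  auto colSums = m.sum(0); // {4, 6}
  auto rowSums = m.sum(1); // {3, 7}
  fmt::println("colSums: {} {}", colSums[0].item(), colSums[1].item());
  fmt::println("rowSums: {} {}", rowSums[0].item(), rowSums[1].item());
}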