Skip to content

Commit

Permalink
add begin/end to ExecuTorch pytree::arr (#8300)
Browse files Browse the repository at this point in the history
Pull Request resolved: #7653

Allows use of range-for.
ghstack-source-id: 265152271
@exported-using-ghexport

Differential Revision: [D68166302](https://our.internmc.facebook.com/intern/diff/D68166302/)

Co-authored-by: Github Executorch <github_executorch@arm.com>
  • Loading branch information
pytorchbot and Github Executorch authored Feb 7, 2025
1 parent 5cab322 commit 456928f
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 2 deletions.
4 changes: 2 additions & 2 deletions extension/pytree/aten_util/ivalue_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ std::pair<std::vector<at::Tensor>, std::unique_ptr<TreeSpec<Empty>>> flatten(
auto p = flatten(c);

std::vector<at::Tensor> tensors;
for (int i = 0; i < p.first.size(); ++i) {
tensors.emplace_back(p.first[i]->toTensor());
for (const auto& item : p.first) {
tensors.emplace_back(item->toTensor());
}

return {tensors, std::move(p.second)};
Expand Down
16 changes: 16 additions & 0 deletions extension/pytree/pytree.h
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,22 @@ struct arr {
return data_.get();
}

T* begin() {
return data_.get();
}

T* end() {
return begin() + size();
}

const T* begin() const {
return data_.get();
}

const T* end() const {
return begin() + size();
}

inline size_t size() const {
return n_;
}
Expand Down
14 changes: 14 additions & 0 deletions extension/pytree/test/test_pytree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,27 @@
#include <gtest/gtest.h>
#include <string>

using ::executorch::extension::pytree::arr;
using ::executorch::extension::pytree::ContainerHandle;
using ::executorch::extension::pytree::Key;
using ::executorch::extension::pytree::Kind;
using ::executorch::extension::pytree::unflatten;

using Leaf = int32_t;

TEST(PyTreeTest, ArrBasic) {
arr<int> x(5);
ASSERT_EQ(x.size(), 5);
for (int ii = 0; ii < x.size(); ++ii) {
x[ii] = 2 * ii;
}
int idx = 0;
for (const auto item : x) {
EXPECT_EQ(item, 2 * idx);
++idx;
}
}

TEST(PyTreeTest, List) {
Leaf items[2] = {11, 12};
std::string spec = "L2#1#1($,$)";
Expand Down

0 comments on commit 456928f

Please sign in to comment.