-
Notifications
You must be signed in to change notification settings - Fork 44
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
155 additions
and
135 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
#pragma once | ||
|
||
#include <ATen/ATen.h> | ||
#include <ATen/Parallel.h> | ||
#include "../hash_map_impl.h" | ||
#include "parallel_hashmap/phmap.h" | ||
|
||
namespace pyg { | ||
namespace classes { | ||
|
||
template <typename KeyType> | ||
struct CPUHashMapImpl : HashMapImpl { | ||
public: | ||
using ValueType = int64_t; | ||
|
||
CPUHashMapImpl(const at::Tensor& key) { | ||
map_.reserve(key.numel()); | ||
|
||
const auto num_threads = at::get_num_threads(); | ||
const auto grain_size = | ||
std::max((key.numel() + num_threads - 1) / num_threads, | ||
at::internal::GRAIN_SIZE); | ||
const auto key_data = key.data_ptr<KeyType>(); | ||
|
||
at::parallel_for(0, key.numel(), grain_size, [&](int64_t beg, int64_t end) { | ||
for (int64_t i = beg; i < end; ++i) { | ||
auto [iterator, inserted] = map_.insert({key_data[i], i}); | ||
TORCH_CHECK(inserted, "Found duplicated key in 'HashMap'."); | ||
} | ||
}); | ||
} | ||
|
||
at::Tensor get(const at::Tensor& query) override { | ||
const auto options = at::TensorOptions().dtype(at::kLong); | ||
const auto out = at::empty({query.numel()}, options); | ||
auto out_data = out.data_ptr<int64_t>(); | ||
|
||
const auto num_threads = at::get_num_threads(); | ||
const auto grain_size = | ||
std::max((query.numel() + num_threads - 1) / num_threads, | ||
at::internal::GRAIN_SIZE); | ||
const auto query_data = query.data_ptr<int64_t>(); | ||
|
||
at::parallel_for(0, query.numel(), grain_size, [&](int64_t b, int64_t e) { | ||
for (int64_t i = b; i < e; ++i) { | ||
auto it = map_.find(query_data[i]); | ||
out_data[i] = (it != map_.end()) ? it->second : -1; | ||
} | ||
}); | ||
|
||
return out; | ||
} | ||
|
||
private: | ||
phmap::parallel_flat_hash_map< | ||
KeyType, | ||
ValueType, | ||
phmap::priv::hash_default_hash<KeyType>, | ||
phmap::priv::hash_default_eq<KeyType>, | ||
phmap::priv::Allocator<std::pair<const KeyType, ValueType>>, | ||
12, | ||
std::mutex> | ||
map_; | ||
}; | ||
|
||
} // namespace classes | ||
} // namespace pyg |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
#include "hash_map.h" | ||
|
||
#include <torch/library.h> | ||
#include "cpu/hash_map_impl.h" | ||
|
||
namespace pyg { | ||
namespace classes { | ||
|
||
HashMap::HashMap(const at::Tensor& key) { | ||
at::TensorArg key_arg{key, "key", 0}; | ||
at::CheckedFrom c{"HashMap.init"}; | ||
at::checkDeviceType(c, key, at::DeviceType::CPU); | ||
at::checkDim(c, key_arg, 1); | ||
at::checkContiguous(c, key_arg); | ||
|
||
// clang-format off | ||
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, | ||
key.scalar_type(), | ||
"hash_map_init", | ||
[&] { | ||
/* if (key.is_cpu) { */ | ||
map_ = std::make_unique<CPUHashMapImpl<scalar_t>>(key); | ||
/* } else { */ | ||
/* AT_ERROR("Received invalid device type for 'HashMap'."); */ | ||
/* } */ | ||
}); | ||
// clang-format on | ||
} | ||
|
||
at::Tensor HashMap::get(const at::Tensor& query) { | ||
at::TensorArg query_arg{query, "query", 0}; | ||
at::CheckedFrom c{"HashMap.get"}; | ||
at::checkDeviceType(c, query, at::DeviceType::CPU); | ||
at::checkDim(c, query_arg, 1); | ||
at::checkContiguous(c, query_arg); | ||
|
||
return map_->get(query); | ||
} | ||
|
||
TORCH_LIBRARY(pyg, m) { | ||
m.class_<HashMap>("HashMap") | ||
.def(torch::init<at::Tensor&>()) | ||
.def("get", &HashMap::get); | ||
} | ||
|
||
} // namespace classes | ||
} // namespace pyg |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#pragma once | ||
|
||
#include <ATen/ATen.h> | ||
#include "hash_map_impl.h" | ||
|
||
namespace pyg { | ||
namespace classes { | ||
|
||
struct HashMap : torch::CustomClassHolder { | ||
public: | ||
HashMap(const at::Tensor& key); | ||
at::Tensor get(const at::Tensor& query); | ||
|
||
private: | ||
std::unique_ptr<HashMapImpl> map_; | ||
}; | ||
|
||
} // namespace classes | ||
} // namespace pyg |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
#pragma once | ||
|
||
#include <ATen/ATen.h> | ||
|
||
namespace pyg { | ||
namespace classes { | ||
|
||
struct HashMapImpl { | ||
virtual ~HashMapImpl() = default; | ||
virtual at::Tensor get(const at::Tensor& query) = 0; | ||
}; | ||
|
||
} // namespace classes | ||
} // namespace pyg |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,15 @@ | ||
#include <ATen/ATen.h> | ||
#include <gtest/gtest.h> | ||
|
||
#include "pyg_lib/csrc/classes/cpu/hash_map.h" | ||
/* #include "pyg_lib/csrc/classes/hash_map.h" */ | ||
|
||
TEST(CPUHashMapTest, BasicAssertions) { | ||
auto options = at::TensorOptions().dtype(at::kLong); | ||
auto key = at::tensor({0, 10, 30, 20}, options); | ||
TEST(HashMapTest, BasicAssertions) { | ||
/* auto options = at::TensorOptions().dtype(at::kLong); */ | ||
/* auto key = at::tensor({0, 10, 30, 20}, options); */ | ||
|
||
auto map = pyg::classes::CPUHashMap(key); | ||
/* auto map = pyg::classes::HashMap(key); */ | ||
|
||
auto query = at::tensor({30, 10, 20, 40}, options); | ||
auto expected = at::tensor({2, 1, 3, -1}, options); | ||
EXPECT_TRUE(at::equal(map.get(query), expected)); | ||
/* auto query = at::tensor({30, 10, 20, 40}, options); */ | ||
/* auto expected = at::tensor({2, 1, 3, -1}, options); */ | ||
/* EXPECT_TRUE(at::equal(map.get(query), expected)); */ | ||
} |