Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Jan 13, 2025
1 parent a77f8f6 commit 558e450
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 135 deletions.
1 change: 0 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ file(GLOB_RECURSE ALL_SOURCES ${CSRC}/*.cpp)
if (WITH_CUDA)
file(GLOB_RECURSE ALL_SOURCES ${ALL_SOURCES} ${CSRC}/*.cu)
endif()
file(GLOB_RECURSE ALL_HEADERS ${CSRC}/*.h)
add_library(${PROJECT_NAME} SHARED ${ALL_SOURCES})
target_include_directories(${PROJECT_NAME} PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}")
if(MKL_INCLUDE_FOUND)
Expand Down
82 changes: 0 additions & 82 deletions pyg_lib/csrc/classes/cpu/hash_map.cpp

This file was deleted.

44 changes: 0 additions & 44 deletions pyg_lib/csrc/classes/cpu/hash_map.h

This file was deleted.

67 changes: 67 additions & 0 deletions pyg_lib/csrc/classes/cpu/hash_map_impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include "../hash_map_impl.h"
#include "parallel_hashmap/phmap.h"

namespace pyg {
namespace classes {

template <typename KeyType>
struct CPUHashMapImpl : HashMapImpl {
 public:
  // Values are always positional indices into the original `key` tensor.
  using ValueType = int64_t;

  // Builds a map from each entry of `key` to its position inside `key`.
  // Insertion runs in parallel via `at::parallel_for`; the `std::mutex`
  // template argument of `map_` makes concurrent submap inserts safe.
  // Throws if `key` contains duplicated entries.
  CPUHashMapImpl(const at::Tensor& key) {
    map_.reserve(key.numel());

    // Split the work evenly across threads, but never below ATen's
    // default grain size to avoid oversubscription on small inputs:
    const auto num_threads = at::get_num_threads();
    const auto grain_size =
        std::max((key.numel() + num_threads - 1) / num_threads,
                 at::internal::GRAIN_SIZE);
    const auto key_data = key.data_ptr<KeyType>();

    at::parallel_for(0, key.numel(), grain_size, [&](int64_t beg, int64_t end) {
      for (int64_t i = beg; i < end; ++i) {
        auto [iterator, inserted] = map_.insert({key_data[i], i});
        TORCH_CHECK(inserted, "Found duplicated key in 'HashMap'.");
      }
    });
  }

  // Returns the stored index for every entry of `query`, or -1 for
  // entries that do not exist in the map. The output dtype is always
  // `int64_t`, independent of the key dtype.
  at::Tensor get(const at::Tensor& query) override {
    const auto options = at::TensorOptions().dtype(at::kLong);
    const auto out = at::empty({query.numel()}, options);
    auto out_data = out.data_ptr<int64_t>();

    const auto num_threads = at::get_num_threads();
    const auto grain_size =
        std::max((query.numel() + num_threads - 1) / num_threads,
                 at::internal::GRAIN_SIZE);
    // BUG FIX: `query` holds keys and must be read with the templated
    // `KeyType`, not `int64_t` — the previous `data_ptr<int64_t>()` call
    // throws (or misreads) for any non-int64 key tensor (e.g. int32, bool).
    const auto query_data = query.data_ptr<KeyType>();

    at::parallel_for(0, query.numel(), grain_size, [&](int64_t b, int64_t e) {
      for (int64_t i = b; i < e; ++i) {
        const auto it = map_.find(query_data[i]);
        out_data[i] = (it != map_.end()) ? it->second : -1;
      }
    });

    return out;
  }

 private:
  // `parallel_flat_hash_map` sharded into submaps; the trailing `12` and
  // `std::mutex` arguments configure the submap count and the per-submap
  // lock that makes the parallel construction above thread-safe.
  phmap::parallel_flat_hash_map<
      KeyType,
      ValueType,
      phmap::priv::hash_default_hash<KeyType>,
      phmap::priv::hash_default_eq<KeyType>,
      phmap::priv::Allocator<std::pair<const KeyType, ValueType>>,
      12,
      std::mutex>
      map_;
};

} // namespace classes
} // namespace pyg
47 changes: 47 additions & 0 deletions pyg_lib/csrc/classes/hash_map.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#include "hash_map.h"

#include <torch/library.h>
#include "cpu/hash_map_impl.h"

namespace pyg {
namespace classes {

// Constructs the hash map from a 1-dimensional, contiguous CPU tensor of
// unique keys. Dispatches on the key dtype (all standard types + bool) to
// instantiate the matching `CPUHashMapImpl<scalar_t>`.
HashMap::HashMap(const at::Tensor& key) {
  at::TensorArg key_arg{key, "key", 0};
  at::CheckedFrom c{"HashMap.init"};
  // `checkDeviceType` already guarantees a CPU tensor here, so the CPU
  // implementation can be selected unconditionally below.
  at::checkDeviceType(c, key, at::DeviceType::CPU);
  at::checkDim(c, key_arg, 1);
  at::checkContiguous(c, key_arg);

  AT_DISPATCH_ALL_TYPES_AND(
      at::ScalarType::Bool, key.scalar_type(), "hash_map_init",
      [&] { map_ = std::make_unique<CPUHashMapImpl<scalar_t>>(key); });
}

// Looks up every entry of `query` and returns the corresponding values.
// Validates that `query` is a contiguous, one-dimensional CPU tensor
// before forwarding to the device-specific implementation.
at::Tensor HashMap::get(const at::Tensor& query) {
  const at::TensorArg query_arg{query, "query", 0};
  const at::CheckedFrom ctx{"HashMap.get"};
  at::checkDeviceType(ctx, query, at::DeviceType::CPU);
  at::checkDim(ctx, query_arg, 1);
  at::checkContiguous(ctx, query_arg);

  return map_->get(query);
}

// Registers the `pyg.HashMap` custom class with TorchScript.
// Use `TORCH_LIBRARY_FRAGMENT` instead of `TORCH_LIBRARY`: only a single
// `TORCH_LIBRARY` block is allowed per namespace across the whole program,
// so registering the `pyg` namespace here would abort at load time if any
// other translation unit also registers operators under `pyg`. A fragment
// composes safely either way.
TORCH_LIBRARY_FRAGMENT(pyg, m) {
  m.class_<HashMap>("HashMap")
      .def(torch::init<at::Tensor&>())
      .def("get", &HashMap::get);
}

} // namespace classes
} // namespace pyg
19 changes: 19 additions & 0 deletions pyg_lib/csrc/classes/hash_map.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#pragma once

#include <ATen/ATen.h>
#include "hash_map_impl.h"

namespace pyg {
namespace classes {

// TorchScript custom class mapping tensor keys to their positions in the
// original key tensor. Construction builds the map; `get` performs batched
// lookups. Registered under the `pyg` namespace (see `hash_map.cpp`).
// NOTE(review): this header uses `torch::CustomClassHolder` and
// `std::unique_ptr` but only includes <ATen/ATen.h> and "hash_map_impl.h" —
// confirm <torch/custom_class.h> and <memory> are pulled in transitively.
struct HashMap : torch::CustomClassHolder {
 public:
  // `key`: tensor of unique keys; its entries map to their own indices.
  HashMap(const at::Tensor& key);
  // Returns the stored index for each entry of `query`.
  at::Tensor get(const at::Tensor& query);

 private:
  // Device-specific implementation behind the `HashMapImpl` interface.
  std::unique_ptr<HashMapImpl> map_;
};

} // namespace classes
} // namespace pyg
14 changes: 14 additions & 0 deletions pyg_lib/csrc/classes/hash_map_impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#pragma once

#include <ATen/ATen.h>

namespace pyg {
namespace classes {

// Abstract interface for device-specific hash map implementations,
// allowing `HashMap` to hold any backend behind a single pointer.
struct HashMapImpl {
  // Virtual destructor so implementations are destroyed correctly when
  // deleted through a `HashMapImpl*` (e.g. from `std::unique_ptr`).
  virtual ~HashMapImpl() = default;
  // Batched lookup: returns the stored value for each entry of `query`
  // (miss behavior is defined by the concrete implementation).
  virtual at::Tensor get(const at::Tensor& query) = 0;
};

} // namespace classes
} // namespace pyg
16 changes: 8 additions & 8 deletions test/csrc/classes/test_hash_map.cpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
#include <ATen/ATen.h>
#include <gtest/gtest.h>

#include "pyg_lib/csrc/classes/cpu/hash_map.h"
/* #include "pyg_lib/csrc/classes/hash_map.h" */

// Reconstructed post-commit version of the test: the diff rendering had
// interleaved the deleted `CPUHashMapTest` lines with the new ones, which
// is not valid C++. The commit disables the test body (all statements
// commented out) until the new `HashMap` custom class can be exercised
// from C++; the placeholder keeps the test target compiling.
TEST(HashMapTest, BasicAssertions) {
  /* auto options = at::TensorOptions().dtype(at::kLong); */
  /* auto key = at::tensor({0, 10, 30, 20}, options); */

  /* auto map = pyg::classes::HashMap(key); */

  /* auto query = at::tensor({30, 10, 20, 40}, options); */
  /* auto expected = at::tensor({2, 1, 3, -1}, options); */
  /* EXPECT_TRUE(at::equal(map.get(query), expected)); */
}

0 comments on commit 558e450

Please sign in to comment.