Skip to content

Commit

Permalink
hetero dist relabel update
Browse files Browse the repository at this point in the history
  • Loading branch information
kgajdamo committed Dec 1, 2023
1 parent a5fcc87 commit c65a353
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 68 deletions.
6 changes: 3 additions & 3 deletions pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ merge_outputs(
}

const auto p_size = partition_ids.size();
std::vector<int64_t> sampled_neighbors_per_node(p_size);
std::vector<int64_t> num_sampled_neighbors_per_node(p_size);

const auto scalar_type = node_ids[0].scalar_type();
AT_DISPATCH_INTEGRAL_TYPES(scalar_type, "merge_outputs_kernel", [&] {
Expand Down Expand Up @@ -106,7 +106,7 @@ merge_outputs(
batch_data[j]);
}

sampled_neighbors_per_node[j] = end_node - begin_node;
num_sampled_neighbors_per_node[j] = end_node - begin_node;
}
});

Expand All @@ -128,7 +128,7 @@ merge_outputs(
});

return std::make_tuple(out_node_id, out_edge_id, out_batch,
sampled_neighbors_per_node);
num_sampled_neighbors_per_node);
}

#define DISPATCH_MERGE_OUTPUTS(disjoint, ...) \
Expand Down
111 changes: 77 additions & 34 deletions pyg_lib/csrc/sampler/cpu/dist_relabel_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ relabel(
const std::vector<edge_type>& edge_types,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<node_type, at::Tensor>& sampled_nodes_with_duplicates_dict,
const c10::Dict<rel_type, std::vector<int64_t>>&
const c10::Dict<rel_type, std::vector<std::vector<int64_t>>>&
num_sampled_neighbors_per_node_dict,
const c10::Dict<node_type, int64_t>& num_nodes_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& batch_dict,
Expand All @@ -117,9 +117,16 @@ relabel(
phmap::flat_hash_map<node_type, scalar_t*> batch_data_dict;
phmap::flat_hash_map<edge_type, std::vector<scalar_t>> sampled_rows_dict;
phmap::flat_hash_map<edge_type, std::vector<scalar_t>> sampled_cols_dict;
// `srcs_slice_dict` stores, for each edge type, the half-open range
// [begin, end) of local src node indices processed in the current layer.
// Local src nodes (`sampled_rows`) are generated from this range, so the
// ranges of a given edge type never overlap across layers: the range of the
// next layer starts at or after the end of the range of the previous layer.
phmap::flat_hash_map<edge_type, std::pair<size_t, size_t>> srcs_slice_dict;

phmap::flat_hash_map<node_type, Mapper<node_t, scalar_t>> mapper_dict;
phmap::flat_hash_map<node_type, std::pair<size_t, size_t>> slice_dict;
phmap::flat_hash_map<node_type, int64_t> srcs_offset_dict;

const bool parallel = at::get_num_threads() > 1 && edge_types.size() > 1;
std::vector<std::vector<edge_type>> threads_edge_types;
Expand All @@ -129,6 +136,14 @@ relabel(
sampled_rows_dict[k];
sampled_cols_dict[k];

// `num_sampled_neighbors_per_node_dict` maps each edge type to the
// number of neighbors sampled by every src node. The counts are stored
// in a separate vector for each layer, so the size of the first vector
// is the number of src nodes in the first layer.
size_t num_src_nodes =
num_sampled_neighbors_per_node_dict.at(to_rel_type(k))[0].size();
srcs_slice_dict[k] = {0, num_src_nodes};

if (parallel) {
// Each thread is assigned edge types that have the same dst node
// type. Thanks to this, each thread will operate on a separate
Expand Down Expand Up @@ -161,6 +176,7 @@ relabel(
{k, sampled_nodes_with_duplicates_dict.at(k).data_ptr<scalar_t>()});
mapper_dict.insert({k, Mapper<node_t, scalar_t>(N)});
slice_dict[k] = {0, 0};
srcs_offset_dict[k] = 0;
if constexpr (disjoint) {
batch_data_dict.insert(
{k, batch_dict.value().at(k).data_ptr<scalar_t>()});
Expand All @@ -178,44 +194,71 @@ relabel(
}
}
}
at::parallel_for(
0, threads_edge_types.size(), 1, [&](size_t _s, size_t _e) {
for (auto j = _s; j < _e; j++) {
for (const auto& k : threads_edge_types[j]) {
const auto src = !csc ? std::get<0>(k) : std::get<2>(k);
const auto dst = !csc ? std::get<2>(k) : std::get<0>(k);

const auto num_sampled_neighbors_size =
num_sampled_neighbors_per_node_dict.at(to_rel_type(k)).size();

if (num_sampled_neighbors_size == 0) {
continue;
}

for (auto i = 0; i < num_sampled_neighbors_size; i++) {
auto& dst_mapper = mapper_dict.at(dst);
auto& dst_sampled_nodes_data = sampled_nodes_data_dict.at(dst);

slice_dict.at(dst).second +=
num_sampled_neighbors_per_node_dict.at(to_rel_type(k))[i];
auto [begin, end] = slice_dict.at(dst);

for (auto j = begin; j < end; j++) {
std::pair<scalar_t, bool> res;
if constexpr (!disjoint) {
res = dst_mapper.insert(dst_sampled_nodes_data[j]);
} else {
res = dst_mapper.insert({batch_data_dict.at(dst)[j],
dst_sampled_nodes_data[j]});
size_t num_layers =
num_sampled_neighbors_per_node_dict.at(to_rel_type(edge_types[0]))
.size();
// Iterate over the layers
for (auto ell = 0; ell < num_layers; ++ell) {
at::parallel_for(
0, threads_edge_types.size(), 1, [&](size_t _s, size_t _e) {
for (auto t = _s; t < _e; t++) {
for (const auto& k : threads_edge_types[t]) {
const auto dst = !csc ? std::get<2>(k) : std::get<0>(k);

auto [src_begin, src_end] = srcs_slice_dict.at(k);

for (auto i = src_begin; i < src_end; i++) {
auto& dst_mapper = mapper_dict.at(dst);
auto& dst_sampled_nodes_data =
sampled_nodes_data_dict.at(dst);

// For each dst node type, `slice_dict` holds the half-open range
// [begin, end) of the nodes sampled by the src node `i`.
// The indices in this range point to global dst nodes
// in `dst_sampled_nodes_data`.
slice_dict.at(dst).second +=
num_sampled_neighbors_per_node_dict.at(
to_rel_type(k))[ell][i - src_begin];
auto [begin, end] = slice_dict.at(dst);

for (auto j = begin; j < end; j++) {
std::pair<scalar_t, bool> res;
if constexpr (!disjoint) {
res = dst_mapper.insert(dst_sampled_nodes_data[j]);
} else {
res = dst_mapper.insert({batch_data_dict.at(dst)[j],
dst_sampled_nodes_data[j]});
}
sampled_rows_dict.at(k).push_back(i);
sampled_cols_dict.at(k).push_back(res.first);
}
sampled_rows_dict.at(k).push_back(i);
sampled_cols_dict.at(k).push_back(res.first);
slice_dict.at(dst).first = end;
}
slice_dict.at(dst).first = end;
}
}
});

// Get local src nodes ranges for the next layer
if (ell < num_layers - 1) {
for (const auto& k : edge_types) {
// Edges with the same src node types will have the same src node
// offsets.
const auto src = !csc ? std::get<0>(k) : std::get<2>(k);
if (srcs_offset_dict[src] < srcs_slice_dict.at(k).second) {
srcs_offset_dict[src] = srcs_slice_dict.at(k).second;
}
});
}
for (const auto& k : edge_types) {
const auto src = !csc ? std::get<0>(k) : std::get<2>(k);
srcs_slice_dict[k] = {
srcs_offset_dict.at(src),
srcs_offset_dict.at(src) + num_sampled_neighbors_per_node_dict
.at(to_rel_type(k))[ell + 1]
.size()};
}
}
}

for (const auto& k : edge_types) {
const auto edges = get_sampled_edges<scalar_t>(
Expand Down Expand Up @@ -254,7 +297,7 @@ hetero_relabel_neighborhood_kernel(
const std::vector<edge_type>& edge_types,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<node_type, at::Tensor>& sampled_nodes_with_duplicates_dict,
const c10::Dict<rel_type, std::vector<int64_t>>&
const c10::Dict<rel_type, std::vector<std::vector<int64_t>>>&
num_sampled_neighbors_per_node_dict,
const c10::Dict<node_type, int64_t>& num_nodes_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& batch_dict,
Expand Down
20 changes: 11 additions & 9 deletions pyg_lib/csrc/sampler/dist_relabel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace sampler {
std::tuple<at::Tensor, at::Tensor> relabel_neighborhood(
const at::Tensor& seed,
const at::Tensor& sampled_nodes_with_duplicates,
const std::vector<int64_t>& sampled_neighbors_per_node,
const std::vector<int64_t>& num_sampled_neighbors_per_node,
const int64_t num_nodes,
const c10::optional<at::Tensor>& batch,
bool csc,
Expand All @@ -28,7 +28,8 @@ std::tuple<at::Tensor, at::Tensor> relabel_neighborhood(
.findSchemaOrThrow("pyg::relabel_neighborhood", "")
.typed<decltype(relabel_neighborhood)>();
return op.call(seed, sampled_nodes_with_duplicates,
sampled_neighbors_per_node, num_nodes, batch, csc, disjoint);
num_sampled_neighbors_per_node, num_nodes, batch, csc,
disjoint);
}

std::tuple<c10::Dict<rel_type, at::Tensor>, c10::Dict<rel_type, at::Tensor>>
Expand All @@ -37,8 +38,8 @@ hetero_relabel_neighborhood(
const std::vector<edge_type>& edge_types,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<node_type, at::Tensor>& sampled_nodes_with_duplicates_dict,
const c10::Dict<rel_type, std::vector<int64_t>>&
sampled_neighbors_per_node_dict,
const c10::Dict<rel_type, std::vector<std::vector<int64_t>>>&
num_sampled_neighbors_per_node_dict,
const c10::Dict<node_type, int64_t>& num_nodes_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& batch_dict,
bool csc,
Expand All @@ -62,21 +63,22 @@ hetero_relabel_neighborhood(
.typed<decltype(hetero_relabel_neighborhood)>();
return op.call(node_types, edge_types, seed_dict,
sampled_nodes_with_duplicates_dict,
sampled_neighbors_per_node_dict, num_nodes_dict, batch_dict,
csc, disjoint);
num_sampled_neighbors_per_node_dict, num_nodes_dict,
batch_dict, csc, disjoint);
}

TORCH_LIBRARY_FRAGMENT(pyg, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::relabel_neighborhood(Tensor seed, Tensor "
"sampled_nodes_with_duplicates, int[] sampled_neighbors_per_node, int "
"sampled_nodes_with_duplicates, int[] num_sampled_neighbors_per_node, "
"int "
"num_nodes, Tensor? batch = None, bool csc = False, bool disjoint = "
"False) -> (Tensor, Tensor)"));
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::hetero_relabel_neighborhood(str[] node_types, (str, str, str)[] "
"edge_types, Dict(str, Tensor) seed_dict, Dict(str, Tensor) "
"sampled_nodes_with_duplicates_dict, Dict(str, int[]) "
"sampled_neighbors_per_node_dict, Dict(str, int) num_nodes_dict, "
"sampled_nodes_with_duplicates_dict, Dict(str, int[][]) "
"num_sampled_neighbors_per_node_dict, Dict(str, int) num_nodes_dict, "
"Dict(str, Tensor)? batch_dict = None, bool csc = False, bool disjoint = "
"False) -> (Dict(str, Tensor), Dict(str, Tensor))"));
}
Expand Down
6 changes: 3 additions & 3 deletions pyg_lib/csrc/sampler/dist_relabel.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ PYG_API
std::tuple<at::Tensor, at::Tensor> relabel_neighborhood(
const at::Tensor& seed,
const at::Tensor& sampled_nodes_with_duplicates,
const std::vector<int64_t>& sampled_neighbors_per_node,
const std::vector<int64_t>& num_sampled_neighbors_per_node,
const int64_t num_nodes,
const c10::optional<at::Tensor>& batch = c10::nullopt,
bool csc = false,
Expand All @@ -32,8 +32,8 @@ hetero_relabel_neighborhood(
const std::vector<edge_type>& edge_types,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<node_type, at::Tensor>& sampled_nodes_with_duplicates_dict,
const c10::Dict<rel_type, std::vector<int64_t>>&
sampled_neighbors_per_node_dict,
const c10::Dict<rel_type, std::vector<std::vector<int64_t>>>&
num_sampled_neighbors_per_node_dict,
const c10::Dict<node_type, int64_t>& num_nodes_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& batch_dict =
c10::nullopt,
Expand Down
15 changes: 9 additions & 6 deletions test/csrc/sampler/test_dist_merge_outputs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ TEST(DistMergeOutputsTest, BasicAssertions) {
auto expected_edges = at::tensor({14, 15, 16, 17, 18, 19, 20}, options);
EXPECT_TRUE(at::equal(std::get<1>(out), expected_edges));

const std::vector<int64_t> expected_sampled_neighbors_per_node = {2, 1, 2, 2};
EXPECT_EQ(std::get<3>(out), expected_sampled_neighbors_per_node);
const std::vector<int64_t> expected_num_sampled_neighbors_per_node = {2, 1, 2,
2};
EXPECT_EQ(std::get<3>(out), expected_num_sampled_neighbors_per_node);
}

TEST(DistMergeOutputsAllNeighborsTest, BasicAssertions) {
Expand Down Expand Up @@ -82,8 +83,9 @@ TEST(DistMergeOutputsAllNeighborsTest, BasicAssertions) {
auto expected_edges = at::tensor({14, 15, 16, 17, 18, 19, 20, 21}, options);
EXPECT_TRUE(at::equal(std::get<1>(out), expected_edges));

const std::vector<int64_t> expected_sampled_neighbors_per_node = {2, 1, 2, 3};
EXPECT_EQ(std::get<3>(out), expected_sampled_neighbors_per_node);
const std::vector<int64_t> expected_num_sampled_neighbors_per_node = {2, 1, 2,
3};
EXPECT_EQ(std::get<3>(out), expected_num_sampled_neighbors_per_node);
}

TEST(DistDisjointMergeOutputsTest, BasicAssertions) {
Expand Down Expand Up @@ -124,6 +126,7 @@ TEST(DistDisjointMergeOutputsTest, BasicAssertions) {
auto expected_batch = at::tensor({0, 0, 1, 2, 2, 3, 3}, options);
EXPECT_TRUE(at::equal(std::get<2>(out).value(), expected_batch));

const std::vector<int64_t> expected_sampled_neighbors_per_node = {2, 1, 2, 2};
EXPECT_EQ(std::get<3>(out), expected_sampled_neighbors_per_node);
const std::vector<int64_t> expected_num_sampled_neighbors_per_node = {2, 1, 2,
2};
EXPECT_EQ(std::get<3>(out), expected_num_sampled_neighbors_per_node);
}
Loading

0 comments on commit c65a353

Please sign in to comment.