Skip to content

Commit

Permalink
Added optimization of biased sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
OlhaBabicheva committed Oct 31, 2023
1 parent 9fc1afc commit 225785d
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,11 +223,23 @@ class NeighborSampler {

// Case 2: Multinomial sampling:
else {
const auto index = at::multinomial(weight, count, replace);
const auto index_data = index.data_ptr<int64_t>();
for (size_t i = 0; i < index.numel(); ++i) {
add(row_start + index_data[i], global_src_node, local_src_node,
dst_mapper, out_global_dst_nodes);
if (replace) {
const auto index = at::multinomial(weight, count, replace);
const auto index_data = index.data_ptr<int64_t>();
for (size_t i = 0; i < index.numel(); ++i) {
add(row_start + index_data[i], global_src_node, local_src_node,

Check warning on line 230 in pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp

View check run for this annotation

Codecov / codecov/patch

pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp#L227-L230

Added lines #L227 - L230 were not covered by tests
dst_mapper, out_global_dst_nodes);
}
}

Check warning on line 233 in pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp

View check run for this annotation

Codecov / codecov/patch

pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp#L233

Added line #L233 was not covered by tests
else {
const auto rand = at::empty_like(weight).uniform_();
const auto a = (rand.log() / weight);
const auto index = std::get<1>(a.topk(count));
const auto index_data = index.data_ptr<int64_t>();
for (size_t i = 0; i < index.numel(); ++i) {
add(row_start + index_data[i], global_src_node, local_src_node,
dst_mapper, out_global_dst_nodes);
}
}
}
}
Expand Down

0 comments on commit 225785d

Please sign in to comment.