From 225785d9d406812de0d38ad8cddf0728436ef60c Mon Sep 17 00:00:00 2001 From: OlhaBabicheva Date: Tue, 31 Oct 2023 07:20:48 +0000 Subject: [PATCH] Added optimization of biased sampling --- pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp | 22 +++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index cfc679d01..83a850aa3 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -223,11 +223,23 @@ class NeighborSampler { // Case 2: Multinomial sampling: else { - const auto index = at::multinomial(weight, count, replace); - const auto index_data = index.data_ptr(); - for (size_t i = 0; i < index.numel(); ++i) { - add(row_start + index_data[i], global_src_node, local_src_node, - dst_mapper, out_global_dst_nodes); + if (replace) { + const auto index = at::multinomial(weight, count, replace); + const auto index_data = index.data_ptr(); + for (size_t i = 0; i < index.numel(); ++i) { + add(row_start + index_data[i], global_src_node, local_src_node, + dst_mapper, out_global_dst_nodes); + } + } + else { + const auto rand = at::empty_like(weight).uniform_(); + const auto a = (rand.log() / weight); + const auto index = std::get<1>(a.topk(count)); + const auto index_data = index.data_ptr(); + for (size_t i = 0; i < index.numel(); ++i) { + add(row_start + index_data[i], global_src_node, local_src_node, + dst_mapper, out_global_dst_nodes); + } } } }