Commit
Update `DistNeighborSampler` for hetero (#8503)
This PR improves the distributed hetero sampling algorithm.

**IMPORTANT INFO**: This PR is complementary to [#284](pyg-team/pyg-lib#284) from pyg-lib. The pyg-lib PR needs to be merged for this one to work properly.

**Description:**

Distributed hetero neighbor sampling is analogous to homo sampling, but more complicated because of the different node and edge types. Sampling in distributed training imitates the `hetero_neighbor_sample()` function in pyg-lib, so the mechanism and the variable names are similar.

Because the results have to be synchronized between machines after each layer is sampled, the loop over the layers is implemented in Python. The two main loops iterate over layers and, within each layer, over edge types. Inside the loop, the `sample_one_hop()` function is called, which performs sampling for one layer. Its input is data of a single type, so its execution is almost identical to the homo case.

Depending on whether the input nodes are located on the local partition or a remote one, `sample_one_hop()` either performs the sampling itself or sends an RPC request to the remote machine to do so. The `dist_neighbor_sample()`->`neighbor_sample()` function is used for sampling. Nodes are sampled with duplicates so that they can later be used to construct local-to-global node mappings.

When all machines have finished sampling, their outputs are merged and synchronized in the same way as for homo. The results then return to the `node_sample()` function, where they are written to the output dictionaries and the src nodes for the next layer are calculated. After all layers have been processed, the global node indices are mapped to local ones in the `hetero_dist_relabel()` function.
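The control flow described above can be sketched in plain Python. This is a single-process toy under stated assumptions, not the code from this PR: there is no partitioning or RPC, "sampling" simply takes the first `fanout` neighbors, and the function names `toy_sample_one_hop` and `toy_hetero_sample`, the adjacency layout, and the edge-type convention `(seed type, relation, neighbor type)` are all hypothetical. The names `src`, `with_dupl` and `out` only loosely mirror the `node_dict` fields.

```python
from typing import Dict, List, Tuple

# Toy convention (an assumption of this sketch, not PyG's definition):
# (seed node type, relation, neighbor node type)
EdgeType = Tuple[str, str, str]
Adj = Dict[int, List[int]]  # seed node id -> neighbor ids


def toy_sample_one_hop(adj: Adj, seeds: List[int], fanout: int):
    """Sample one layer for a single edge type, keeping duplicates."""
    nbrs: List[int] = []
    counts: List[int] = []
    for seed in seeds:
        picked = adj.get(seed, [])[:fanout]  # deterministic stand-in for sampling
        nbrs.extend(picked)
        counts.append(len(picked))
    return nbrs, counts


def toy_hetero_sample(adj_dict: Dict[EdgeType, Adj], seed_type: str,
                      seeds: List[int], fanouts: List[int]):
    node_types = {t for (a, _, b) in adj_dict for t in (a, b)}
    src = {t: [] for t in node_types}        # seeds for the current layer
    out = {t: [] for t in node_types}        # unique sampled nodes
    with_dupl = {t: [] for t in node_types}  # sampled nodes incl. duplicates
    src[seed_type] = list(seeds)
    out[seed_type] = list(seeds)
    nbrs_per_node = {et: [] for et in adj_dict}  # per edge type, per layer

    for fanout in fanouts:                    # outer loop: layers
        new_src = {t: [] for t in node_types}
        for et, adj in adj_dict.items():      # inner loop: edge types
            seed_t, _, nbr_t = et
            nbrs, counts = toy_sample_one_hop(adj, src[seed_t], fanout)
            with_dupl[nbr_t].extend(nbrs)
            nbrs_per_node[et].append(counts)
            for n in nbrs:                    # unseen nodes seed the next layer
                if n not in out[nbr_t] and n not in new_src[nbr_t]:
                    new_src[nbr_t].append(n)
        for t in node_types:
            out[t].extend(new_src[t])
        src = new_src

    # Final step: map global node ids to local (per-type) indices.
    local = {t: {g: i for i, g in enumerate(out[t])} for t in node_types}
    return out, with_dupl, nbrs_per_node, local


adj_dict = {
    ("paper", "written_by", "author"): {0: [10, 11], 1: [11]},
    ("author", "writes", "paper"): {10: [0, 2], 11: [1, 3]},
}
out, with_dupl, nbrs_per_node, local = toy_hetero_sample(
    adj_dict, seed_type="paper", seeds=[0, 1], fanouts=[2, 2])
```

In this toy run, author `11` is reached from both papers, so it appears twice in `with_dupl["author"]` but only once in `out["author"]`. In the distributed setting the per-edge-type call would additionally be dispatched locally or over RPC, and the merged results synchronized before the next layer starts.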
Information about some of the variables used in the `node_sample()` function:

- `node_dict` - class storing information about nodes. It has three fields: `src`, `with_dupl`, `out`, which are described in more detail in the distributed/utils.py file.
- `batch_dict` - class used when sampling with the disjoint option. It stores information about which subgraph each node belongs to. Just like `node_dict`, it has three fields: `src`, `with_dupl`, `out`.
- `sampled_nbrs_per_node_dict` - a dictionary that stores the number of neighbors sampled by each src node. To facilitate subsequent operations, the entry for each edge type is additionally divided into layers.
- `num_sampled_nodes_dict`, `num_sampled_edges_dict` - needed for HGAM to work.

---------

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>