Update DistNeighborSampler for hetero (#8503)
The purpose of this PR is to improve the distributed hetero sampling
algorithm.
**IMPORTANT INFO**: This PR is complementary to
[#284](pyg-team/pyg-lib#284) in pyg-lib. The
pyg-lib PR needs to be merged for this one to work properly.


**Description:** (sorry if too long)
Distributed hetero neighbor sampling is analogous to homo sampling, but
more involved due to the presence of different node and edge types.
Sampling in distributed training mirrors the `hetero_neighbor_sample()`
function in pyg-lib, so the mechanism of action and the variable
nomenclature are similar.
Because distributed training must synchronize results across machines
after each layer is sampled, the loop iterating over the layers is
implemented in Python.

The two main loops iterate sequentially: an outer loop over layers and
an inner loop over edge types. Inside the loops, the `sample_one_hop()`
function is called, which performs sampling for one layer.
The input to `sample_one_hop()` is data of a single edge type, so its
execution is almost identical to the homo case.
Depending on whether the input nodes are located on the local partition
or a remote one, `sample_one_hop()` either performs the sampling itself
or sends an RPC request to the remote machine to do so. The
`dist_neighbor_sample()` -> `neighbor_sample()` functions are used for
sampling. Nodes are sampled with duplicates so that they can later be
used to construct local-to-global node mappings.
When all machines have finished sampling, their outputs are merged and
synchronized in the same way as for homo.
The results then return to the `node_sample()` function, where they are
written to the output dictionaries and the src nodes for the next layer
are computed.
After going through all the layers, the global node indices are finally
mapped to local ones in the `hetero_dist_relabel()` function.
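
To make the control flow concrete, here is a minimal, self-contained toy
sketch of the scheme described above. It is *not* the PR's
implementation: the toy graph, the partition rule, and all helper names
(`one_hop()`, the simulated RPC inside `sample_one_hop()`) are
hypothetical, and the cross-machine RPC is replaced by a plain local
call:

```python
import random
from collections import defaultdict

# Toy hetero graph: edge type -> {src node: list of dst nodes}:
adj = {
    ('v0', 'to', 'v1'): {0: [10, 11], 2: [11, 12]},
    ('v1', 'to', 'v0'): {10: [0, 2], 11: [0], 12: [2]},
}
num_partitions = 2


def partition_of(node):  # stand-in for the partition book
    return node % num_partitions


def one_hop(edge_type, srcs, k):
    # Sample up to `k` neighbors per src node; duplicates across src
    # nodes are kept so local-to-global mappings can be built later:
    sampled, nbrs_per_node = [], []
    for src in srcs:
        nbrs = adj[edge_type].get(src, [])
        out = random.sample(nbrs, min(k, len(nbrs)))
        sampled += out
        nbrs_per_node.append(len(out))
    return sampled, nbrs_per_node


def sample_one_hop(edge_type, srcs, k):
    # Split seeds by owning partition; local seeds are sampled directly,
    # remote ones would be requested via RPC (simulated by a local call):
    sampled, nbrs_per_node = [], []
    for p in range(num_partitions):
        part_srcs = [s for s in srcs if partition_of(s) == p]
        out, cnt = one_hop(edge_type, part_srcs, k)  # RPC if `p` is remote
        sampled += out
        nbrs_per_node += cnt
    return sampled, nbrs_per_node  # merged as in the homo case


def node_sample(seeds_by_type, num_neighbors):
    # Outer loop over layers, inner loop over edge types; the real code
    # synchronizes results across machines after every layer:
    out = defaultdict(list, {t: list(s) for t, s in seeds_by_type.items()})
    srcs = dict(seeds_by_type)
    for k in num_neighbors:
        next_srcs = defaultdict(list)
        for edge_type in adj:
            src_type, _, dst_type = edge_type
            sampled, _ = sample_one_hop(edge_type, srcs.get(src_type, []), k)
            out[dst_type] += sampled
            next_srcs[dst_type] += sampled
        srcs = next_srcs
    return out  # global ids; the real code now relabels them to local ids


print(dict(node_sample({'v0': [0]}, num_neighbors=[2, 2])))
```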

Information about some of the variables used in the `node_sample()`
function (a rough sketch of their layout follows below):
- `node_dict`: a class storing information about nodes. It has three
  fields, `src`, `with_dupl`, and `out`, which are described in more
  detail in the `distributed/utils.py` file.
- `batch_dict`: a class used when sampling with the disjoint option. It
  stores information about the assignment of nodes to subgraphs. Just
  like `node_dict`, it has three fields: `src`, `with_dupl`, and `out`.
- `sampled_nbrs_per_node_dict`: a dictionary storing the number of
  neighbors sampled by each src node. To facilitate subsequent
  operations, the entry for each edge type is additionally divided into
  layers.
- `num_sampled_nodes_dict`, `num_sampled_edges_dict`: needed for HGAM to
  work.
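
As a rough illustration only (the actual classes live in
`torch_geometric/distributed/utils.py`; the exact field types below are
an assumption based on the description above, and the sample counts are
made up):

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

import torch


@dataclass
class NodeDict:
    # One entry per node type (assumed layout):
    src: Dict[str, torch.Tensor] = field(default_factory=dict)        # seeds of the current layer
    with_dupl: Dict[str, torch.Tensor] = field(default_factory=dict)  # sampled nodes, duplicates kept
    out: Dict[str, torch.Tensor] = field(default_factory=dict)        # deduplicated output nodes


# `batch_dict` mirrors this layout for subgraph ids in disjoint mode, and
# `sampled_nbrs_per_node_dict` maps each edge type to per-layer counts:
sampled_nbrs_per_node_dict: Dict[Tuple[str, str, str], List[List[int]]] = {
    ('v0', 'to', 'v1'): [[2, 2], [1, 3, 0]],  # layer 0, layer 1
}
```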

---------

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
kgajdamo and rusty1s authored Dec 14, 2023
1 parent 331e5d1 commit c5bf8ef
Showing 4 changed files with 298 additions and 142 deletions.
3 changes: 1 addition & 2 deletions test/distributed/test_dist_neighbor_loader.py
@@ -149,8 +149,8 @@ def dist_neighbor_loader_hetero(
     for edge_type in batch.edge_types:
         num_edges = batch[edge_type].edge_index.size(1)

-        assert batch[edge_type].edge_attr.size(0) == num_edges
         if num_edges > 0:  # Test edge mapping:
+            assert batch[edge_type].edge_attr.size(0) == num_edges
             src, _, dst = edge_type
             edge_index = part_data[1]._edge_index[(edge_type, "coo")]
             global_edge_index_1 = torch.stack([
@@ -209,7 +209,6 @@ def test_dist_neighbor_loader_homo(
 @pytest.mark.parametrize('num_parts', [2])
 @pytest.mark.parametrize('num_workers', [0])
 @pytest.mark.parametrize('async_sampling', [True])
-@pytest.mark.skip(reason="Breaks with no attribute 'num_hops'")
 def test_dist_neighbor_loader_hetero(
     tmp_path,
     num_parts,
164 changes: 151 additions & 13 deletions test/distributed/test_dist_neighbor_sampler.py
@@ -6,12 +6,18 @@
 import torch

 from torch_geometric.data import Data
-from torch_geometric.distributed import LocalFeatureStore, LocalGraphStore
+from torch_geometric.datasets import FakeHeteroDataset
+from torch_geometric.distributed import (
+    LocalFeatureStore,
+    LocalGraphStore,
+    Partitioner,
+)
 from torch_geometric.distributed.dist_context import DistContext
 from torch_geometric.distributed.dist_neighbor_sampler import (
     DistNeighborSampler,
     close_sampler,
 )
+from torch_geometric.distributed.partition import load_partition_info
 from torch_geometric.distributed.rpc import init_rpc
 from torch_geometric.sampler import NeighborSampler, NodeSamplerInput
 from torch_geometric.sampler.neighbor_sampler import node_sample
@@ -69,6 +75,21 @@ def create_data(rank: int, world_size: int, time_attr: Optional[str] = None):
     return (feature_store, graph_store), data


+def create_hetero_data(tmp_path: str, rank: int):
+    graph_store = LocalGraphStore.from_partition(tmp_path, pid=rank)
+    feature_store = LocalFeatureStore.from_partition(tmp_path, pid=rank)
+
+    out = load_partition_info(tmp_path, rank)
+
+    feature_store.meta = graph_store.meta = out[0]
+    feature_store.num_partitions = graph_store.num_partitions = out[1]
+    feature_store.partition_idx = graph_store.partition_idx = out[2]
+    feature_store.node_feat_pb = graph_store.node_pb = out[3]
+    feature_store.edge_feat_pb = graph_store.edge_pb = out[4]
+
+    return feature_store, graph_store
+
+
 def dist_neighbor_sampler(
     world_size: int,
     rank: int,
@@ -211,15 +232,94 @@ def dist_neighbor_sampler_temporal(
     assert out_dist.num_sampled_edges == out.num_sampled_edges


+def dist_neighbor_sampler_hetero(
+    data: FakeHeteroDataset,
+    tmp_path: str,
+    world_size: int,
+    rank: int,
+    master_port: int,
+    input_type: str,
+    disjoint: bool = False,
+):
+    dist_data = create_hetero_data(tmp_path, rank)
+
+    current_ctx = DistContext(
+        rank=rank,
+        global_rank=rank,
+        world_size=world_size,
+        global_world_size=world_size,
+        group_name='dist-sampler-test',
+    )
+
+    num_neighbors = [-1, -1]
+    dist_sampler = DistNeighborSampler(
+        data=dist_data,
+        current_ctx=current_ctx,
+        rpc_worker_names={},
+        num_neighbors=num_neighbors,
+        shuffle=False,
+        disjoint=disjoint,
+    )
+
+    # Close RPC & worker group at exit:
+    atexit.register(close_sampler, 0, dist_sampler)
+
+    init_rpc(
+        current_ctx=current_ctx,
+        rpc_worker_names={},
+        master_addr='localhost',
+        master_port=master_port,
+    )
+
+    dist_sampler.register_sampler_rpc()
+    dist_sampler.init_event_loop()
+
+    # Create input nodes such that each belongs to a different partition:
+    node_pb_list = dist_data[1].node_pb[input_type].tolist()
+    node_0 = node_pb_list.index(0)
+    node_1 = node_pb_list.index(1)
+
+    input_node = torch.tensor([node_0, node_1], dtype=torch.int64)
+
+    inputs = NodeSamplerInput(
+        input_id=None,
+        node=input_node,
+        input_type=input_type,
+    )
+
+    # Evaluate the distributed node sample function:
+    out_dist = dist_sampler.event_loop.run_task(
+        coro=dist_sampler.node_sample(inputs))
+
+    sampler = NeighborSampler(
+        data=data,
+        num_neighbors=num_neighbors,
+        disjoint=disjoint,
+    )
+
+    # Evaluate the single-machine node sample function:
+    out = node_sample(inputs, sampler._sample)
+
+    # Compare distributed output with single-machine output:
+    for k in data.node_types:
+        assert torch.equal(out_dist.node[k].sort()[0], out.node[k].sort()[0])
+        assert out_dist.num_sampled_nodes[k] == out.num_sampled_nodes[k]
+        if disjoint:
+            assert torch.equal(
+                out_dist.batch[k].sort()[0],
+                out.batch[k].sort()[0],
+            )
+
+
 @onlyLinux
 @withPackage('pyg_lib')
 @pytest.mark.parametrize('disjoint', [False, True])
 def test_dist_neighbor_sampler(disjoint):
     mp_context = torch.multiprocessing.get_context('spawn')
-    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
-        sock.settimeout(1)
-        sock.bind(('127.0.0.1', 0))
-        port = sock.getsockname()[1]
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.settimeout(1)
+        s.bind(('127.0.0.1', 0))
+        port = s.getsockname()[1]

     world_size = 2
     w0 = mp_context.Process(
@@ -244,10 +344,10 @@ def test_dist_neighbor_sampler(disjoint):
 @pytest.mark.parametrize('temporal_strategy', ['uniform'])
 def test_dist_neighbor_sampler_temporal(seed_time, temporal_strategy):
     mp_context = torch.multiprocessing.get_context('spawn')
-    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
-        sock.settimeout(1)
-        sock.bind(('127.0.0.1', 0))
-        port = sock.getsockname()[1]
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.settimeout(1)
+        s.bind(('127.0.0.1', 0))
+        port = s.getsockname()[1]

     world_size = 2
     w0 = mp_context.Process(
@@ -277,10 +377,10 @@ def test_dist_neighbor_sampler_edge_level_temporal(
     seed_time = torch.tensor(seed_time)

     mp_context = torch.multiprocessing.get_context('spawn')
-    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
-        sock.settimeout(1)
-        sock.bind(('127.0.0.1', 0))
-        port = sock.getsockname()[1]
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.settimeout(1)
+        s.bind(('127.0.0.1', 0))
+        port = s.getsockname()[1]

     world_size = 2
     w0 = mp_context.Process(
Expand All @@ -297,3 +397,41 @@ def test_dist_neighbor_sampler_edge_level_temporal(
w1.start()
w0.join()
w1.join()


@withPackage('pyg_lib')
@pytest.mark.parametrize('disjoint', [False, True])
def test_dist_neighbor_sampler_hetero(tmp_path, disjoint):
mp_context = torch.multiprocessing.get_context('spawn')
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(1)
s.bind(('127.0.0.1', 0))
port = s.getsockname()[1]

world_size = 2
data = FakeHeteroDataset(
num_graphs=1,
avg_num_nodes=100,
avg_degree=3,
num_node_types=2,
num_edge_types=4,
edge_dim=2,
)[0]

partitioner = Partitioner(data, world_size, tmp_path)
partitioner.generate_partition()

w0 = mp_context.Process(
target=dist_neighbor_sampler_hetero,
args=(data, tmp_path, world_size, 0, port, 'v0', disjoint),
)

w1 = mp_context.Process(
target=dist_neighbor_sampler_hetero,
args=(data, tmp_path, world_size, 1, port, 'v1', disjoint),
)

w0.start()
w1.start()
w0.join()
w1.join()