Skip to content

Commit

Permalink
Add Python binding for rabit ops. (#5743)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Jun 2, 2020
1 parent e533908 commit e49607a
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
14 changes: 14 additions & 0 deletions python-package/xgboost/rabit.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ def get_world_size():
return ret


def is_distributed():
    """Report whether rabit is running in distributed mode.

    Returns
    -------
    The value produced by the native ``RabitIsDistributed`` call, passed
    through unchanged.
    """
    return _LIB.RabitIsDistributed()


def tracker_print(msg):
"""Print message to the tracker.
Expand Down Expand Up @@ -143,6 +149,14 @@ def broadcast(data, root):
}


class Op:  # pylint: disable=too-few-public-methods
    """Reduction operations accepted as the ``op`` argument of ``allreduce``.

    The members are plain integer codes.
    NOTE(review): presumably these mirror the op codes of the native rabit
    library -- confirm before renumbering.
    """
    # Element-wise maximum.
    MAX = 0
    # Element-wise minimum.
    MIN = 1
    # Element-wise sum.
    SUM = 2
    # OR reduction.
    OR = 3


def allreduce(data, op, prepare_fun=None):
"""Perform allreduce, return the result.
Expand Down
41 changes: 39 additions & 2 deletions tests/python/test_tracker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import time

from xgboost import RabitTracker
import xgboost as xgb
import pytest
import testing as tm
import numpy as np


def test_rabit_tracker():
Expand All @@ -15,3 +16,39 @@ def test_rabit_tracker():
ret = xgb.rabit.broadcast('test1234', 0)
assert str(ret) == 'test1234'
xgb.rabit.finalize()


def run_rabit_ops(client, n_workers):
    """Run SUM and MAX allreduce on every dask worker and verify the results.

    Parameters
    ----------
    client :
        Dask client whose workers execute the checks.
    n_workers :
        Number of workers the caller expects ``client`` to have; every
        assertion below is expressed in terms of this value.
    """
    # Imported lazily: these require dask, which the caller's test skips on.
    from xgboost.dask import RabitContext, _get_rabit_args, _get_client_workers
    from xgboost import rabit

    workers = list(_get_client_workers(client).keys())
    rabit_args = _get_rabit_args(workers, client)
    # Outside of a RabitContext this process must not report distributed mode.
    assert not rabit.is_distributed()

    def local_test(worker_id):
        with RabitContext(rabit_args):
            assert rabit.is_distributed()

            # Every worker contributes 1, so the sum equals the worker count.
            ones = np.array([1])
            total = rabit.allreduce(ones, rabit.Op.SUM)
            assert total[0] == n_workers

            # Ids are 0..len(workers)-1, so the maximum is the last id.
            ids = np.array([worker_id])
            largest = rabit.allreduce(ids, rabit.Op.MAX)
            assert largest == n_workers - 1

        return 1

    futures = client.map(local_test, range(len(workers)), workers=workers)
    # Each task returns 1 on success; the total confirms all of them ran.
    assert sum(client.gather(futures)) == n_workers


@pytest.mark.skipif(**tm.no_dask())
def test_rabit_ops():
    """Check rabit ops on a local dask cluster with three workers."""
    from distributed import Client, LocalCluster

    n_workers = 3
    # Client is closed first, then the cluster, matching the nested form.
    with LocalCluster(n_workers=n_workers) as cluster, Client(cluster) as client:
        run_rabit_ops(client, n_workers)

0 comments on commit e49607a

Please sign in to comment.