Investigate how the gap between local and Dask predictions can be decreased #3835
Comments
Closed in favor of being tracked in #2302. We decided to keep all feature requests in one place. You are welcome to contribute this feature! Please re-open this issue (or post a comment if you are not the topic starter) if you are actively working on implementing this feature.
I'd like to add that for this issue, we also need to investigate the gap for training with the same parameters and data using the non-Dask distributed training via the LightGBM CLI (https://lightgbm.readthedocs.io/en/latest/Parallel-Learning-Guide.html#run-parallel-learning). That will tell us which differences are due to the way the Dask module works, and which are features of LightGBM itself. The same is true for networking issues. At some point, I'm planning to add support for that in https://github.com/jameslamb/lightgbm-dask-testing so it's easy for anyone to test.
Now that we have a way to test distributed training through the CLI, I ran a comparison against Dask and found that the CLI always gives the same predictions, but the Dask ones can change. When they differ, they differ from the first split of the first tree, because the threshold is different. I made this experiment with two machines, each having one partition of the data, and found that the difference is due to which worker is assigned as the master worker. Is this expected @shiyu1994?

Sample script (run from tests/distributed):

```python
import dask.array as da
import lightgbm as lgb
import numpy as np
from dask.distributed import Client, wait

from _test_distributed import create_data, DistributedMockup

if __name__ == '__main__':
    # train with the CLI through the distributed mockup helper
    data = create_data('regression')
    num_machines = 2
    partitions = np.array_split(data, num_machines)
    train_params = {
        'objective': 'regression',
        'num_machines': num_machines,
    }
    cli = DistributedMockup('../../lightgbm')
    cli.fit(partitions, train_params)
    cli_preds = cli.predict()

    # train with the Dask estimator on the same partitions
    dsk_params = {
        'objective': 'regression',
        'tree_learner': 'data',
        'force_row_wise': True,
        'verbose': 0,
        'num_boost_round': 20,
        'num_leaves': 15,
    }
    client = Client(n_workers=2, threads_per_worker=2)
    dsk_data = da.vstack(partitions)
    dX = dsk_data[:, 1:]
    dy = dsk_data[:, 0]
    dX, dy = client.persist([dX, dy])
    _ = wait([dX, dy])
    client.rebalance()
    dsk = lgb.DaskLGBMRegressor(**dsk_params)
    dsk.fit(dX, dy)
    dsk_preds = dsk.predict(dX).compute()

    # compare predictions and, if they differ, print the first split of the first tree
    try:
        np.testing.assert_equal(dsk_preds, cli_preds)
    except AssertionError:
        cols = ['tree_index', 'node_depth', 'split_feature', 'split_gain', 'threshold']
        cli_bst = lgb.Booster(model_file='model0.txt')
        print('CLI')
        print(cli_bst.trees_to_dataframe().head(1)[cols])
        dsk_bst = dsk.booster_
        print('Dask')
        print(dsk_bst.trees_to_dataframe().head(1)[cols])
```

When they differ, the output shows a different threshold for the first split of the first tree.
Reversing the order in …
@jmoralez Sorry for the late response. I suppose so, IF the data partitions are not identical in the CLI and Dask versions. Otherwise, I'd need more effort to investigate where the difference comes from. In both the CLI and Dask versions of the above example, they should both use … For example, in the CLI case, each process will have:

LightGBM/src/io/dataset_loader.cpp Lines 716 to 719 in 2394b41

I haven't figured out how the LightGBM Dask module finds the bin boundaries yet. But given that the CLI distributed version uses only local data for bin boundary finding, could you please point me to the code where I can find how Dask partitions the data? Thanks!
The partitions are identical, since I create them manually in the script. What I meant is that even with the CLI version alone, the results differ depending on the order of the machines. So in this case we have two machines A and B and two partitions P1 and P2. A always gets P1 and B always gets P2; however, depending on which machine is listed first in the machine list, the results differ. Here's a smaller example:

gen_data.py:

```python
import numpy as np
from sklearn.datasets import make_regression

# build a regression dataset with the label in the first column and split it into two partitions
X, y = make_regression(n_samples=1_000, n_features=4, n_informative=2, random_state=42)
data = np.hstack([y.reshape(-1, 1), X])
partitions = np.array_split(data, 2)
for i, partition in enumerate(partitions):
    np.savetxt(f'train{i}.txt', partition, delimiter=',')
```

train0.conf, train1.conf, mlist.txt (the contents of these files were collapsed and are not reproduced here).
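Since those file contents aren't shown above, here is a minimal sketch of what they might look like, following the Parallel Learning Guide; every value in it (IPs, ports, parameter values) is an assumption for illustration, not the actual files from the experiment:

```python
# Hypothetical reconstruction of the CLI configs for two machines.
# All values here (IPs, ports, parameters) are placeholders.
machines = [('127.0.0.1', 12400), ('127.0.0.1', 12401)]

# mlist.txt: one "ip port" pair per line
with open('mlist.txt', 'w') as f:
    for ip, port in machines:
        f.write(f'{ip} {port}\n')

# trainN.conf: each machine trains on its own partition with the shared machine list
for rank, (ip, port) in enumerate(machines):
    config = '\n'.join([
        'task = train',
        'objective = regression',
        f'data = train{rank}.txt',
        'label_column = 0',
        'tree_learner = data',
        f'num_machines = {len(machines)}',
        f'local_listen_port = {port}',
        'machine_list_filename = mlist.txt',
        'num_iterations = 20',
        'num_leaves = 15',
        f'output_model = model{rank}.txt',
    ])
    with open(f'train{rank}.conf', 'w') as f:
        f.write(config + '\n')
```

The point of this setup is that the order of the lines in mlist.txt is what the experiment varies, since that determines which machine ends up as rank 0.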
If you run this with the machines listed in mlist.txt in one order and then in the reverse order, the threshold of the first split changes.
@jmoralez, @shiyu1994's comment gave me an idea. I think we should always set the config option referenced here:

LightGBM/include/LightGBM/config.h Lines 641 to 644 in dac0dff

I think that since that interface assumes Dask is handling data distribution (through the way your Dask DataFrame or Dask Array is partitioned), from LightGBM's perspective the data is already partitioned before training starts. What do you think?

To be clear, I don't think such a change would change what you observed in #3835 (comment), since the config value is only used when loading a dataset from a file (see these references):
LightGBM/src/io/dataset_loader.cpp Line 184 in aab212a
LightGBM/src/io/dataset_loader.cpp Line 520 in aab212a
LightGBM/src/io/dataset_loader.cpp Line 863 in aab212a
LightGBM/src/io/dataset_loader.cpp Line 926 in aab212a
@jmoralez When using data-distributed training, bin finding for the different features is distributed across the machines by feature. So in your example, the machine with rank 0 will always find the bin mappers for the first part of the features, and the machine with rank 1 will always find the bin mappers for the second part of the features, both using only their local data. So I think the difference makes sense: with different rank assignments, different data will be used to find the bin mappers of the same feature.
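To make that concrete, here is a toy sketch (assuming simple quantile-style binning on NumPy arrays; this is not LightGBM's actual bin-finding code) showing that the two partitions from the examples above produce slightly different boundaries for the same feature, so swapping which rank owns a feature shifts its candidate thresholds:

```python
import numpy as np
from sklearn.datasets import make_regression

# same data and partitioning as in the examples above
X, y = make_regression(n_samples=1_000, n_features=4, n_informative=2, random_state=42)
part0, part1 = np.array_split(X, 2)

feature = 0
quantiles = np.linspace(0, 1, 16)[1:-1]  # 14 interior quantiles as toy bin boundaries

# rank 0 owning the feature -> boundaries from the first partition only;
# swapping the ranks -> boundaries from the second partition instead
bins_if_rank0_owns = np.quantile(part0[:, feature], quantiles)
bins_if_rank1_owns = np.quantile(part1[:, feature], quantiles)

print('max boundary difference:', np.abs(bins_if_rank0_owns - bins_if_rank1_owns).max())
```

Since split thresholds can only be picked from the bin boundaries, a different owner for the feature can move the learned threshold slightly, which matches the small first-split differences observed above.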
@jameslamb I agree, although we should probably not let the user change this, so maybe we should add it here:

LightGBM/python-package/lightgbm/dask.py Line 388 in b1facf5
Thank you for the explanation @shiyu1994. I think that may be the cause of the differences we see in the distributed tests. I'll report back soon after I run some tests.
I've been running some more experiments and it definitely seems that the differences we sometimes observe between local and distributed predictions are due to the slightly different thresholds, because when a sample is very close to a threshold it can end up on the opposite side of a split with respect to the other model. The most dramatic case is with only two leaves: even though the trees are almost identical, a sample whose feature value lies very close to the threshold falls into the opposite leaf and gets a very different prediction.

Sample code:

```python
import dask.dataframe as dd
import lightgbm as lgb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from dask.distributed import Client
from sklearn.datasets import make_regression

client = Client(n_workers=2, threads_per_worker=1)

X, y = make_regression(n_samples=1_000, n_features=4, n_informative=2, random_state=42)
X = pd.DataFrame(X, columns=[f'feature_{i}' for i in range(X.shape[1])])
y = pd.Series(y)
dX = dd.from_pandas(X, npartitions=2)
dy = dd.from_pandas(y, npartitions=2)

params = {
    "boosting_type": "gbdt",
    "random_state": 42,
    "num_leaves": 2,
    "n_estimators": 20,
    "tree_learner": "data",
}

# fit a local model and a distributed one with the same parameters
local_reg = lgb.LGBMRegressor(**params)
local_reg.fit(X, y)
local_preds = local_reg.predict(X)

dask_reg = lgb.DaskLGBMRegressor(**params)
dask_reg.fit(dX, dy)
dask_preds = dask_reg.predict(dX).compute()

# find the sample with the largest relative difference between the two models
pct_diffs = abs(local_preds - dask_preds) / abs(local_preds)
top_diff_idx = np.argsort(-pct_diffs)[0]
x = X.iloc[[top_diff_idx]]  # keep as a one-row DataFrame so it can be passed to predict

# plot the first tree of each model side by side
fig, ax = plt.subplots(ncols=2, figsize=(10, 6))
for reg, title, axi in zip([local_reg, dask_reg], ['local', 'distributed'], ax.flat):
    lgb.plot_tree(reg, ax=axi, show_info=['internal_value', 'leaf_count'], tree_index=0, precision=5)
    axi.set_title(title)
```

With a higher number of leaves the differences are smaller, but they grow as the iterations go by. For example, using 8 leaves with the code above, the prediction per iteration for a specific sample looks like this:

Additional code for the plot:

```python
# prediction after each iteration for the most-different sample, local vs Dask
local_result = []
dask_result = []
dask_bst = dask_reg.booster_
for i in range(1, 21):
    local_result.append(local_reg.predict(x, num_iteration=i)[0])
    dask_result.append(dask_bst.predict(x, num_iteration=i)[0])

fig, ax = plt.subplots(figsize=(6, 4))
pd.DataFrame({'local': local_result, 'dask': dask_result}).plot(marker='.', ax=ax)
```

From the plot above it can be seen that at tree index 6 the predictions become very different. The structure of the built trees and the path the sample takes through them, as well as the predicted values, can be seen in the image below. We can see there that the slight differences seem to accumulate, and by tree index 5 the tree structures start to differ.
Nice investigation @jmoralez! I think this still makes sense given @shiyu1994's explanation in #3835 (comment). Since the Dask code is splitting the data across the workers, each worker finds bin boundaries from different local rows. To eliminate that specific source of difference and check for other types (like maybe loss of precision during collective operations such as syncing histograms), you could try creating a Dask DataFrame with two partitions where each partition is a full copy of X, as in the sketch below.
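A minimal sketch of that check (the parameter values and the comparison shown here are assumptions for illustration, not a definitive test):

```python
import dask.dataframe as dd
import lightgbm as lgb
import pandas as pd
from dask.distributed import Client, wait
from sklearn.datasets import make_regression

client = Client(n_workers=2, threads_per_worker=1)

X, y = make_regression(n_samples=1_000, n_features=4, n_informative=2, random_state=42)
X = pd.DataFrame(X, columns=[f'feature_{i}' for i in range(X.shape[1])])
y = pd.Series(y)

# two partitions, each holding a full copy of the data, so every worker
# sees all of the rows when the bin boundaries are found
dX_full = dd.concat([dd.from_pandas(X, npartitions=1), dd.from_pandas(X, npartitions=1)])
dy_full = dd.concat([dd.from_pandas(y, npartitions=1), dd.from_pandas(y, npartitions=1)])
dX_full, dy_full = client.persist([dX_full, dy_full])
_ = wait([dX_full, dy_full])
client.rebalance()

params = {"random_state": 42, "num_leaves": 2, "n_estimators": 20, "tree_learner": "data"}
local_reg = lgb.LGBMRegressor(**params).fit(X, y)
dup_reg = lgb.DaskLGBMRegressor(**params).fit(dX_full, dy_full)

# if local-only bin finding is the main source of the gap, the first split
# of the first tree should now have the same threshold in both models
cols = ['tree_index', 'split_feature', 'threshold']
print(local_reg.booster_.trees_to_dataframe().head(1)[cols])
print(dup_reg.booster_.trees_to_dataframe().head(1)[cols])
```

Any remaining difference would then have to come from something other than bin finding, e.g. precision loss in the collective operations.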
@jmoralez Nice investigation. I think the only way to eliminate the difference is to synchronize the processes when finding the bin boundaries. But that would require designing a distributed algorithm for bin finding, which is nontrivial.
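To illustrate what that synchronization would need to achieve (a toy sketch assuming quantile-style binning; this is neither LightGBM code nor a concrete design proposal): if the ranks pooled their feature values before computing the boundaries, every rank would derive identical bin mappers regardless of machine ordering.

```python
import numpy as np

rng = np.random.default_rng(42)
feature_rank0 = rng.normal(size=500)  # rows held by rank 0
feature_rank1 = rng.normal(size=500)  # rows held by rank 1
quantiles = np.linspace(0, 1, 16)[1:-1]

# today: whichever rank owns the feature computes boundaries from its local rows only
local_only_bins = np.quantile(feature_rank0, quantiles)

# "synchronized": boundaries computed over the pooled rows, so the result
# no longer depends on which rank owns the feature
pooled_bins = np.quantile(np.concatenate([feature_rank0, feature_rank1]), quantiles)

print('max difference:', np.abs(local_only_bins - pooled_bins).max())
```

The hard part is doing that pooling efficiently without shipping all raw feature values between machines, which is why a proper distributed bin-finding algorithm is nontrivial.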
Summary
Right now, in our CI tests, the difference between local and Dask-based estimators is quite big in terms of the results produced after training with the same set of parameters. It would be good to investigate the reasons and either provide source code fixes or write a guide on how the difference can be mitigated on the user side.
References
#3515 (comment)
LightGBM/tests/python_package_test/test_dask.py
Line 161 in ac706e1
LightGBM/tests/python_package_test/test_dask.py
Lines 206 to 207 in ac706e1
LightGBM/tests/python_package_test/test_dask.py
Lines 304 to 306 in ac706e1
LightGBM/tests/python_package_test/test_dask.py
Lines 436 to 440 in ac706e1