Skip to content

Commit

Permalink
[dask] remove unused private _client attribute (#3904)
Browse files Browse the repository at this point in the history
* Update test_dask.py

* Update dask.py

* Update .vsts-ci.yml

* Revert "Update .vsts-ci.yml"

This reverts commit 98422be.
  • Loading branch information
StrikerRUS authored Feb 3, 2021
1 parent 08c68c9 commit b1e000c
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 27 deletions.
10 changes: 2 additions & 8 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,11 +468,9 @@ def client_(self) -> Client:
def _lgb_getstate(self) -> Dict[Any, Any]:
"""Remove un-picklable attributes before serialization."""
client = self.__dict__.pop("client", None)
self.__dict__.pop("_client", None)
self._other_params.pop("client", None)
out = deepcopy(self.__dict__)
out.update({"_client": None, "client": None})
self._client = client
out.update({"client": None})
self.client = client
return out

Expand Down Expand Up @@ -521,8 +519,7 @@ def _copy_extra_params(source: Union["_DaskLGBMModel", LGBMModel], dest: Union["
attributes = source.__dict__
extra_param_names = set(attributes.keys()).difference(params.keys())
for name in extra_param_names:
if name != "_client":
setattr(dest, name, attributes[name])
setattr(dest, name, attributes[name])


class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
Expand Down Expand Up @@ -554,7 +551,6 @@ def __init__(
**kwargs: Any
):
"""Docstring is inherited from the lightgbm.LGBMClassifier.__init__."""
self._client = client
self.client = client
super().__init__(
boosting_type=boosting_type,
Expand Down Expand Up @@ -672,7 +668,6 @@ def __init__(
**kwargs: Any
):
"""Docstring is inherited from the lightgbm.LGBMRegressor.__init__."""
self._client = client
self.client = client
super().__init__(
boosting_type=boosting_type,
Expand Down Expand Up @@ -779,7 +774,6 @@ def __init__(
**kwargs: Any
):
"""Docstring is inherited from the lightgbm.LGBMRanker.__init__."""
self._client = client
self.client = client
super().__init__(
boosting_type=boosting_type,
Expand Down
20 changes: 1 addition & 19 deletions tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
"""Tests for lightgbm.dask module"""

import inspect
import joblib
import pickle
import socket
from itertools import groupby
Expand All @@ -19,6 +18,7 @@
import cloudpickle
import dask.array as da
import dask.dataframe as dd
import joblib
import numpy as np
import pandas as pd
from scipy.stats import spearmanr
Expand Down Expand Up @@ -488,56 +488,47 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, l

# should be able to use the class without specifying a client
dask_model = model_factory(**params)
assert dask_model._client is None
assert dask_model.client is None
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
dask_model.client_

dask_model.fit(dX, dy, group=dg)
assert dask_model.fitted_
assert dask_model._client is None
assert dask_model.client is None
assert dask_model.client_ == client

preds = dask_model.predict(dX)
assert isinstance(preds, da.Array)
assert dask_model.fitted_
assert dask_model._client is None
assert dask_model.client is None
assert dask_model.client_ == client

local_model = dask_model.to_local()
with pytest.raises(AttributeError):
local_model._client
local_model.client
local_model.client_

# should be able to set client after construction
dask_model = model_factory(**params)
dask_model.set_params(client=client)
assert dask_model._client == client
assert dask_model.client == client

with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
dask_model.client_

dask_model.fit(dX, dy, group=dg)
assert dask_model.fitted_
assert dask_model._client == client
assert dask_model.client == client
assert dask_model.client_ == client

preds = dask_model.predict(dX)
assert isinstance(preds, da.Array)
assert dask_model.fitted_
assert dask_model._client == client
assert dask_model.client == client
assert dask_model.client_ == client

local_model = dask_model.to_local()
assert getattr(local_model, "_client", None) is None
with pytest.raises(AttributeError):
local_model._client
local_model.client
local_model.client_

Expand Down Expand Up @@ -606,10 +597,8 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
dask_model = model_factory(**params)
local_model = dask_model.to_local()
if set_client:
assert dask_model._client == client1
assert dask_model.client == client1
else:
assert dask_model._client is None
assert dask_model.client is None

with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
Expand Down Expand Up @@ -640,14 +629,11 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
serializer=serializer
)

assert model_from_disk._client is None
assert model_from_disk.client is None

if set_client:
assert dask_model._client == client1
assert dask_model.client == client1
else:
assert dask_model._client is None
assert dask_model.client is None

with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
Expand All @@ -674,7 +660,6 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici

assert "client" not in local_model.get_params()
with pytest.raises(AttributeError):
local_model._client
local_model.client
local_model.client_

Expand All @@ -701,17 +686,14 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
)

if set_client:
assert dask_model._client == client1
assert dask_model.client == client1
assert dask_model.client_ == client1
else:
assert dask_model._client is None
assert dask_model.client is None
assert dask_model.client_ == default_client()
assert dask_model.client_ == client2

assert isinstance(fitted_model_from_disk, model_factory)
assert fitted_model_from_disk._client is None
assert fitted_model_from_disk.client is None
assert fitted_model_from_disk.client_ == default_client()
assert fitted_model_from_disk.client_ == client2
Expand Down

0 comments on commit b1e000c

Please sign in to comment.