-
-
Notifications
You must be signed in to change notification settings - Fork 259
/
Copy path_partial.py
218 lines (171 loc) · 5.9 KB
/
_partial.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
from __future__ import absolute_import, division, print_function
import logging
import os
import dask
import numpy as np
import sklearn.utils
from dask.delayed import Delayed
from dask.highlevelgraph import HighLevelGraph
from toolz import partial
from ._compat import DASK_2022_01_0
# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)
def _partial_fit(model, x, y, kwargs=None):
    """Run one ``partial_fit`` step on ``model`` and hand the model back.

    Parameters
    ----------
    model : object
        Estimator exposing a ``partial_fit(x, y, **options)`` method.
    x, y :
        The data block (and target block) for this step.
    kwargs : dict, optional
        Extra keyword arguments forwarded to ``partial_fit``; ``None`` or an
        empty mapping means "no extra options".

    Returns
    -------
    object
        The same ``model`` object, so calls can be chained task-to-task.
    """
    options = kwargs if kwargs else {}
    model.partial_fit(x, y, **options)
    return model
def fit(
    model,
    x,
    y,
    compute=True,
    shuffle_blocks=True,
    random_state=None,
    assume_equal_chunks=False,
    **kwargs,
):
    """Fit scikit learn model against dask arrays

    Model must support the ``partial_fit`` interface for online or batch
    learning.

    Ideally your rows are independent and identically distributed. By default,
    this function will step through chunks of the arrays in random order.

    Parameters
    ----------
    model: sklearn model
        Any model supporting partial_fit interface
    x: dask Array
        Two dimensional array, likely tall and skinny. If it is chunked along
        its columns it is first rechunked to a single column chunk, so every
        ``partial_fit`` call receives all features of its rows.
    y: dask Array or None
        One dimensional array with the same number of row blocks as ``x``.
        When ``None``, ``partial_fit`` is called with ``y=None``.
    compute : bool
        Whether to compute this result
    shuffle_blocks : bool
        Whether to shuffle the blocks with ``random_state`` or not
    random_state : int or numpy.random.RandomState
        Random state to use when shuffling blocks
    assume_equal_chunks : bool
        Currently unused; accepted for backwards compatibility.
    kwargs:
        options to pass to partial_fit

    Returns
    -------
    model or dask.delayed.Delayed
        The fitted model when ``compute`` is True, otherwise a ``Delayed``
        whose computation yields the fitted model.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` have a different number of row blocks, or if
        ``model`` has no ``partial_fit`` method.

    Examples
    --------
    >>> import dask.array as da
    >>> X = da.random.random((10, 3), chunks=(5, 3))
    >>> y = da.random.randint(0, 2, 10, chunks=(5,))

    >>> from sklearn.linear_model import SGDClassifier
    >>> sgd = SGDClassifier()

    >>> sgd = da.learn.fit(sgd, X, y, classes=[1, 0])
    >>> sgd  # doctest: +SKIP
    SGDClassifier(alpha=0.0001, class_weight=None, epsilon=0.1, eta0=0.0,
           fit_intercept=True, l1_ratio=0.15, learning_rate='optimal',
           loss='hinge', n_iter=5, n_jobs=1, penalty='l2', power_t=0.5,
           random_state=None, shuffle=False, verbose=0, warm_start=False)

    This passes all of X and y through the classifier sequentially. We can use
    the classifier as normal on in-memory data

    >>> import numpy as np
    >>> sgd.predict(np.random.random((4, 3)))  # doctest: +SKIP
    array([1, 0, 0, 1])

    Or predict on a larger dataset

    >>> z = da.random.random((400, 3), chunks=(100, 3))
    >>> da.learn.predict(sgd, z)  # doctest: +SKIP
    dask.array<x_11, shape=(400,), chunks=((100, 100, 100, 100),), dtype=int64>
    """
    if hasattr(x, "chunks") and x.ndim > 1 and len(x.chunks[1]) > 1:
        # The task keys below address only column block 0 of each row block;
        # without merging the column chunks, partial_fit would silently be
        # trained on the first chunk of features only.
        x = x.rechunk({1: -1})

    nblocks, x_name = _blocks_and_name(x)
    if y is not None:
        y_nblocks, y_name = _blocks_and_name(y)
        # Explicit check (not ``assert``) so it also runs under ``python -O``.
        if y_nblocks != nblocks:
            raise ValueError(
                "x and y must have the same number of row blocks, "
                "got {} and {}.".format(nblocks, y_nblocks)
            )

    if not hasattr(model, "partial_fit"):
        msg = "The class '{}' does not implement 'partial_fit'."
        raise ValueError(msg.format(type(model)))

    order = list(range(nblocks))
    if shuffle_blocks:
        rng = sklearn.utils.check_random_state(random_state)
        rng.shuffle(order)

    name = "fit-" + dask.base.tokenize(model, x, y, kwargs, order)

    # 2-D array blocks are keyed (name, row_block, col_block); after the
    # rechunk above the only column block is 0.
    if hasattr(x, "chunks") and x.ndim > 1:
        x_extra = (0,)
    else:
        x_extra = ()

    # Build a linear chain of tasks: step i consumes the model produced by
    # step i - 1, starting from the untouched model stored at key (name, -1).
    dsk = {(name, -1): model}
    for i, block in enumerate(order):
        # Without y, hand partial_fit None rather than a ("", block)
        # placeholder tuple that names no key in the graph.
        y_arg = (y_name, block) if y is not None else None
        dsk[(name, i)] = (
            _partial_fit,
            (name, i - 1),
            (x_name, block) + x_extra,
            y_arg,
            kwargs,
        )

    dependencies = [x]
    if y is not None:
        dependencies.append(y)
    new_dsk = HighLevelGraph.from_collections(name, dsk, dependencies=dependencies)

    # Newer dask versions want the layer name passed to Delayed explicitly.
    if DASK_2022_01_0:
        value = Delayed((name, nblocks - 1), new_dsk, layer=name)
    else:
        value = Delayed((name, nblocks - 1), new_dsk)

    if compute:
        return value.compute()
    return value
def _blocks_and_name(obj):
    """Return the number of row blocks and the graph name of a dask collection.

    Parameters
    ----------
    obj : dask Array, DataFrame or Bag
        Arrays are recognised by a ``chunks`` attribute (block count is the
        number of chunks along axis 0). DataFrames and Bags are recognised by
        an ``npartitions`` attribute.

    Returns
    -------
    nblocks : int
        Number of blocks along the first axis / number of partitions.
    name
        ``obj.name`` for arrays and bags, ``obj._name`` for dataframes.

    Raises
    ------
    TypeError
        If ``obj`` has neither ``chunks`` nor ``npartitions``.
    """
    if hasattr(obj, "chunks"):
        return len(obj.chunks[0]), obj.name
    if hasattr(obj, "npartitions"):
        # Dataframes expose their graph name as ``_name``; bags as ``name``.
        name = obj._name if hasattr(obj, "_name") else obj.name
        return obj.npartitions, name
    # Previously this fell through to an UnboundLocalError.
    raise TypeError(
        "Expected a dask Array, DataFrame or Bag (an object with 'chunks' or "
        "'npartitions'), got {!r}.".format(type(obj))
    )
def _predict(model, x):
    """Call ``model.predict`` on ``x`` and return whatever it produces."""
    predictions = model.predict(x)
    return predictions
def predict(model, x):
    """Predict with a scikit learn model

    Parameters
    ----------
    model : scikit learn classifier
    x : dask Array
        See docstring for ``da.learn.fit``

    Returns
    -------
    dask collection
        Lazy predictions. If ``model`` has a non-None ``feature_names_in_``,
        the result of ``x.map_partitions``; otherwise a dask Array with one
        output block per row block of ``x``. That array is 1-D for
        single-output models and 2-D with all output columns in one column
        block for multi-output models.
    """
    func = partial(_predict, model)

    if getattr(model, "feature_names_in_", None) is not None:
        # This branch uses map_partitions / _meta_nonempty, so x is presumably
        # a dask DataFrame here.
        meta = model.predict(x._meta_nonempty)
        return x.map_partitions(func, meta=meta)

    if len(x.chunks[1]) > 1:
        # Every block handed to the model must contain all feature columns.
        x = x.rechunk(chunks=(x.chunks[0], sum(x.chunks[1])))

    # Predict on one all-zero row to discover the output dtype and shape.
    probe = np.zeros((1, x.shape[1]), dtype=x.dtype)
    meta = model.predict(probe)

    if meta.ndim > 1:
        # Multi-output: each output block keeps all meta.shape[1] columns
        # (previously hard-coded to a single column).
        chunks = (x.chunks[0], (meta.shape[1],))
        drop_axis = None
    else:
        chunks = (x.chunks[0],)
        drop_axis = 1
    return x.map_blocks(func, chunks=chunks, meta=meta, drop_axis=drop_axis)
def _copy_partial_doc(cls):
    """Replace ``cls.__doc__`` with its scikit-learn base's docstring plus a note.

    The first line of the base docstring is kept as the summary. It is
    followed by a deprecation notice pointing to
    :class:`dask_ml.wrappers.Incremental` and a paragraph explaining the
    block-wise ``partial_fit`` behaviour. The rest of the base docstring
    comes after that.

    Parameters
    ----------
    cls : type
        Class to document. Its MRO is searched for the first class whose
        ``__module__`` starts with ``"sklearn"``.

    Returns
    -------
    type
        ``cls`` itself (with ``__doc__`` replaced), so this works as a class
        decorator.
    """
    for base in cls.mro():
        if base.__module__.startswith("sklearn"):
            break
    # NOTE(review): if no sklearn base exists, the loop finishes with ``base``
    # bound to the last MRO entry (``object``), whose docstring is then used.

    # Source docstrings always use "\n" line endings, so split on that rather
    # than os.linesep ("\r\n" on Windows, which never matches). ``__doc__`` is
    # None when docstrings are stripped (python -OO).
    lines = (base.__doc__ or "").split("\n")
    header, rest = lines[0], lines[1:]

    insert = """

    .. deprecated:: 0.6.0
       Use the :class:`dask_ml.wrappers.Incremental` meta-estimator instead.

    This class wraps scikit-learn's {classname}. When a dask-array is passed
    to our ``fit`` method, the array is passed block-wise to the scikit-learn
    class' ``partial_fit`` method. This will allow you to fit the estimator
    on larger-than memory datasets sequentially (block-wise), but without any
    parallelism, or any ability to distribute across a cluster.""".format(
        classname=base.__name__
    )

    cls.__doc__ = "\n".join([header + insert] + rest)
    return cls