# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base classes for LIT models."""
import abc
from collections.abc import Iterable, Iterator, Mapping
import inspect
import itertools
import multiprocessing.pool  # for ThreadPool
from typing import Optional, Union

from absl import logging
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import types
from lit_nlp.lib import utils
import numpy as np

JsonDict = types.JsonDict
Spec = types.Spec


def maybe_copy_np(arr):
  """Decide if we should make a copy of an array in order to release memory.

  NumPy arrays may be views into other array objects, by which a small array
  can maintain a persistent pointer to a large block of memory that prevents
  it from being garbage collected. This can quickly lead to memory issues as
  such blocks accumulate during inference.

  Args:
    arr: a NumPy array

  Returns:
    arr, or a copy of arr
  """
  if not isinstance(arr, np.ndarray):
    return arr
  # If this is not a view of another array, there is nothing to detach.
  if arr.base is None:
    return arr
  # TensorFlow provides a bridge to share memory between TensorFlow and NumPy
  # arrays. This looks like a view into an array, but the base is a
  # tensorflow_wrapper rather than an array, so the view heuristics below
  # don't work. We can check for this case by checking if arr.base has the
  # ndim attribute.
  # https://github.com/tensorflow/tensorflow/blob/6ed79e8429730c33dc894175da7a1849a8e3e57f/tensorflow/python/lib/core/ndarray_tensor_bridge.cc#L90
  if not hasattr(arr.base, 'ndim'):
    return np.copy(arr)
  # Heuristic to check if we should 'detach' this array from the parent blob.
  # We want to know if this array is a view that might leak memory.
  # The simplest check is if arr.base is larger than arr, but we don't want to
  # make unnecessary copies when this is just due to slicing along a batch,
  # because the other rows are likely still in use.
  # TODO(lit-dev): keep an eye on this; if we continue to have memory issues,
  # we can make copies more aggressively.
  if arr.base.ndim > 1 and np.prod(arr.base.shape[1:]) > np.prod(arr.shape):
    return np.copy(arr)
  # Otherwise this is just a batch slice or reshape, so return the view as-is.
  return arr


def scrub_numpy_refs(output: JsonDict) -> JsonDict:
  """Scrub numpy pointers; see maybe_copy_np() and Model.predict()."""
  return {k: maybe_copy_np(v) for k, v in output.items()}


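# A minimal, illustrative sketch (not part of the LIT API) of how
# maybe_copy_np() treats views. The array names below are hypothetical. A
# small slice of a large padded matrix is detached (copied) so the big parent
# blob can be garbage collected, while a plain batch-row slice is returned
# as-is, since the rest of the parent is likely still in use.
def _maybe_copy_np_example():
  """Demonstrates the copy-vs-passthrough heuristic on toy arrays."""
  big = np.zeros((1000, 768), dtype=np.float32)  # e.g. a padded activation blob

  small_slice = big[0, :5]  # view that is much smaller than one row of `big`
  detached = maybe_copy_np(small_slice)
  assert detached.base is None  # copied, so `big` is no longer pinned

  batch_row = big[0]  # a full row: just a slice along the batch dimension
  assert maybe_copy_np(batch_row) is batch_row  # returned without copying

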
class Model(metaclass=abc.ABCMeta):
  """Base class for LIT models."""

  def description(self) -> str:
    """Return a human-readable description of this component.

    Defaults to the class docstring, but subclasses may override this to be
    instance-dependent - for example, including the path from which the model
    was loaded.

    Returns:
      (string) A human-readable description for display in the UI.
    """
    return inspect.getdoc(self) or ''

  def __str__(self) -> str:
    classname = self.__class__.__module__ + '.' + self.__class__.__qualname__
    indented_description = '  ' + self.description().replace('\n', '\n  ')
    return f'{classname}(...):\n{indented_description}'

  def _repr_pretty_(self, p, cycle):
    """Pretty-printing for IPython environments, both notebooks and repl."""
    if not cycle:
      p.text(str(self))
    else:
      p.text('...')

  @classmethod
  def init_spec(cls) -> Optional[Spec]:
    """Attempts to infer a Spec describing a Model's constructor parameters.

    The Model base class attempts to infer a Spec for the constructor using
    `lit_nlp.api.types.infer_spec_for_func()`.

    If successful, this function will return a `dict[str, LitType]`. If
    unsuccessful (i.e., the inferencer raises a `TypeError` because it
    encounters a parameter that is not supported by `infer_spec_for_func()`),
    this function will return None, log a warning describing where and how the
    inference failed, and LIT users **will not** be able to load new instances
    of this model from the UI.

    Returns:
      A Spec representation of the Model's constructor, or None if a Spec
      could not be inferred.
    """
    try:
      spec = types.infer_spec_for_func(cls.__init__)
    except TypeError as e:
      spec = None
      logging.warning(
          "Unable to infer init spec for model '%s'. %s", cls.__name__, str(e)
      )
    return spec

  def is_compatible_with_dataset(self, dataset: lit_dataset.Dataset) -> bool:
    """Return true if this model is compatible with the dataset spec."""
    dataset_spec = dataset.spec()
    for key, field_spec in self.input_spec().items():
      if key in dataset_spec:
        # If the field is in the dataset, make sure it's compatible.
        if not dataset_spec[key].is_compatible(field_spec):
          return False
      else:
        # If the field isn't in the dataset, only allow it if the model marks
        # the field as optional.
        if field_spec.required:
          return False
    return True

  @property
  def supports_concurrent_predictions(self):
    """Indicates support for multiple concurrent predict calls across threads.

    Defaults to False.

    Returns:
      (bool) True if the model can handle multiple concurrent calls to its
      `predict` method.
    """
    return False

  def load(self, path: str):
    """Load and return a new instance of this model loaded from a new path.

    By default this method does nothing. Models can override this method in
    order to allow dynamic model loading in LIT through the UI. Models
    overriding this method should use the provided path string to create and
    return a new instance of their model class.

    Args:
      path: The path to the persisted model information, used in the model's
        construction.

    Returns:
      (Model) A model loaded with information from the provided path.
    """
    del path
    raise NotImplementedError('Model has no load method defined for dynamic '
                              'loading')

  @abc.abstractmethod
  def input_spec(self) -> types.Spec:
    """Return a spec describing model inputs."""
    return

  @abc.abstractmethod
  def output_spec(self) -> types.Spec:
    """Return a spec describing model outputs."""
    return

  def get_embedding_table(self) -> tuple[list[str], np.ndarray]:
    """Return the full vocabulary and embedding table.

    Implementing this is optional, but needed for some techniques such as
    HotFlip which use the embedding table to search over candidate words.

    Returns:
      (<string>[vocab_size], <float32>[vocab_size, emb_dim])
    """
    raise NotImplementedError('get_embedding_table() not implemented for ' +
                              self.__class__.__name__)

  @abc.abstractmethod
  def predict(self, inputs: Iterable[JsonDict], **kw) -> Iterable[JsonDict]:
    """Run prediction on a list of inputs and return the outputs."""
    pass


ModelMap = Mapping[str, Model]


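# A minimal, illustrative sketch (hypothetical class, not part of LIT) of what
# init_spec() inference looks like. Constructors whose parameters have simple,
# supported type annotations can be described automatically; if any parameter
# type is unsupported, inference fails and init_spec() returns None after
# logging a warning.
class _ExampleInferableModel(Model):
  """Toy model whose constructor parameters can be inferred as a Spec."""

  def __init__(self, model_path: str, threshold: float = 0.5):
    self._model_path = model_path
    self._threshold = threshold

  def input_spec(self) -> types.Spec:
    return {'text': types.TextSegment()}

  def output_spec(self) -> types.Spec:
    return {'score': types.Scalar()}

  def predict(self, inputs: Iterable[JsonDict], **kw) -> Iterable[JsonDict]:
    return [{'score': self._threshold} for _ in inputs]


# _ExampleInferableModel.init_spec() would return a Spec describing
# 'model_path' and 'threshold' (or None, with a logged warning, if inference
# fails for this constructor).

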
class ModelWrapper(Model):
  """Wrapper for a LIT model.

  This class acts as an identity function, with pass-through implementations
  of the Model API. Subclasses of this can implement only those methods that
  need to be modified.
  """

  def __init__(self, model: Model):
    self._wrapped = model

  @property
  def wrapped(self):
    """Access the wrapped model."""
    return self._wrapped

  def description(self) -> str:
    return self.wrapped.description()

  @property
  def supports_concurrent_predictions(self):
    return self.wrapped.supports_concurrent_predictions

  def predict(
      self, inputs: Iterable[JsonDict], *args, **kw
  ) -> Iterable[JsonDict]:
    return self.wrapped.predict(inputs, *args, **kw)

  def load(self, path: str):
    """Load a new model and wrap it with this class."""
    new_model = self.wrapped.load(path)
    return self.__class__(new_model)

  def input_spec(self) -> types.Spec:
    return self.wrapped.input_spec()

  def output_spec(self) -> types.Spec:
    return self.wrapped.output_spec()

  ##
  # Special methods
  def get_embedding_table(self) -> tuple[list[str], np.ndarray]:
    return self.wrapped.get_embedding_table()


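# A minimal, illustrative sketch (hypothetical, not part of LIT) of the
# intended use of ModelWrapper: subclass it and override only what you need.
# Here only predict() is overridden to post-process each output dict; a real
# wrapper that adds output fields should also extend output_spec() to match.
class _ExamplePostprocessingWrapper(ModelWrapper):
  """Delegates to the wrapped model and tags each prediction."""

  def predict(
      self, inputs: Iterable[JsonDict], *args, **kw
  ) -> Iterable[JsonDict]:
    for output in self.wrapped.predict(inputs, *args, **kw):
      # 'postprocessed' is a hypothetical extra field for illustration only.
      yield dict(output, postprocessed=True)

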
class BatchedModel(Model):
  """Generic base class for a batched model.

  Subclasses need to implement predict_minibatch() and optionally
  max_minibatch_size().
  """

  def max_minibatch_size(self) -> int:
    """Maximum minibatch size for this model."""
    return 1

  @property
  def supports_concurrent_predictions(self):
    return False

  @abc.abstractmethod
  def predict_minibatch(self, inputs: list[JsonDict]) -> list[JsonDict]:
    """Run prediction on a batch of inputs.

    Args:
      inputs: sequence of inputs, following model.input_spec()

    Returns:
      list of outputs, following model.output_spec()
    """
    pass

  def predict(self, inputs: Iterable[JsonDict], **kw) -> Iterable[JsonDict]:
    """Run prediction on a dataset.

    This uses minibatch inference for efficiency, but yields per-example
    output.

    This will also copy some NumPy arrays if they look like slices of a larger
    tensor. This adds some overhead, but reduces memory leaks by allowing the
    source tensor (which may be a large padded matrix) to be garbage
    collected.

    Args:
      inputs: iterable of input dicts
      **kw: additional kwargs passed to predict_minibatch()

    Returns:
      model outputs, for each input
    """
    results = self.batched_predict(inputs, **kw)
    results = (scrub_numpy_refs(res) for res in results)
    return results

  def batched_predict(
      self, inputs: Iterable[JsonDict], **kw
  ) -> Iterator[JsonDict]:
    """Internal helper to predict using minibatches."""
    minibatch_size = self.max_minibatch_size(**kw)
    minibatch = []
    for ex in inputs:
      if len(minibatch) < minibatch_size:
        minibatch.append(ex)
      if len(minibatch) >= minibatch_size:
        yield from self.predict_minibatch(minibatch, **kw)
        minibatch = []
    if len(minibatch) > 0:  # pylint: disable=g-explicit-length-test
      yield from self.predict_minibatch(minibatch, **kw)


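# A minimal, illustrative sketch (hypothetical, not part of LIT) of a concrete
# BatchedModel. The field names ('text', 'score') and the toy scoring logic
# are assumptions for illustration; a real subclass would run actual model
# inference inside predict_minibatch().
class _ExampleLengthModel(BatchedModel):
  """Toy model that 'scores' each example by the length of its text."""

  def max_minibatch_size(self) -> int:
    return 32  # predict() groups inputs into minibatches of up to 32

  def input_spec(self) -> types.Spec:
    return {'text': types.TextSegment()}

  def output_spec(self) -> types.Spec:
    return {'score': types.Scalar()}

  def predict_minibatch(self, inputs: list[JsonDict]) -> list[JsonDict]:
    # A real model would batch-encode `inputs` and run inference here.
    return [{'score': float(len(ex['text']))} for ex in inputs]


# Example usage: list(_ExampleLengthModel().predict([{'text': 'hello'}]))
# yields [{'score': 5.0}], with batching handled by BatchedModel.predict().

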
class BatchedRemoteModel(Model):
  """Generic base class for remotely-hosted models.

  Implements concurrent request batching; subclasses need only implement
  predict_minibatch() and max_minibatch_size().

  If a subclass overrides __init__, it should be sure to call
  super().__init__() to set up the thread pool.
  """

  def __init__(self,
               max_concurrent_requests: int = 4,
               max_qps: Union[int, float] = 25):
    # Use a local thread pool for concurrent requests, so we can keep the
    # server busy during network transit time and local pre/post-processing.
    self._max_qps = max_qps
    self._pool = multiprocessing.pool.ThreadPool(max_concurrent_requests)

  def predict(
      self,
      inputs: Iterable[JsonDict],
      *unused_args,
      parallel=True,
      **unused_kwargs
  ) -> Iterator[JsonDict]:
    batches = utils.batch_iterator(
        inputs, max_batch_size=self.max_minibatch_size())
    batches = utils.rate_limit(batches, self._max_qps)
    if parallel:
      pred_batches = self._pool.imap(self.predict_minibatch, batches)
    else:  # don't use the thread pool; useful for debugging
      pred_batches = map(self.predict_minibatch, batches)
    return itertools.chain.from_iterable(pred_batches)

  def max_minibatch_size(self) -> int:
    """Maximum minibatch size for this model. Subclasses can override this."""
    return 1

  @property
  def supports_concurrent_predictions(self):
    """Remote models can handle concurrent predictions by default."""
    return True

  @abc.abstractmethod
  def predict_minibatch(self, inputs: list[JsonDict]) -> list[JsonDict]:
    """Run prediction on a batch of inputs.

    Subclasses should implement this.

    Args:
      inputs: sequence of inputs, following model.input_spec()

    Returns:
      list of outputs, following model.output_spec()
    """
    return


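# A minimal, illustrative sketch (hypothetical, not part of LIT) of a
# BatchedRemoteModel subclass. The endpoint URL, field names, and placeholder
# response are assumptions; a real subclass would issue an RPC or HTTP request
# for each minibatch inside predict_minibatch(). Note the super().__init__()
# call, which sets up the thread pool used for concurrent requests.
class _ExampleRemoteModel(BatchedRemoteModel):
  """Sends each minibatch to a (hypothetical) remote scoring service."""

  def __init__(self, endpoint_url: str, batch_size: int = 16):
    super().__init__(max_concurrent_requests=4, max_qps=25)
    self._endpoint_url = endpoint_url
    self._batch_size = batch_size

  def max_minibatch_size(self) -> int:
    return self._batch_size

  def input_spec(self) -> types.Spec:
    return {'text': types.TextSegment()}

  def output_spec(self) -> types.Spec:
    return {'score': types.Scalar()}

  def predict_minibatch(self, inputs: list[JsonDict]) -> list[JsonDict]:
    # Placeholder: a real implementation would send `inputs` to
    # self._endpoint_url and parse the response into output dicts.
    return [{'score': 0.0} for _ in inputs]

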
class ProjectorModel(BatchedModel, metaclass=abc.ABCMeta):
  """LIT Model API for dimensionality reduction."""

  ##
  # Training methods
  @abc.abstractmethod
  def fit_transform(self, inputs: Iterable[JsonDict]) -> list[JsonDict]:
    """For internal use by scikit-learn-based models."""
    pass

  ##
  # LIT model API
  def input_spec(self):
    # 'x' denotes input features
    return {'x': types.Embeddings()}

  def output_spec(self):
    # 'z' denotes projected embeddings
    return {'z': types.Embeddings()}

  def max_minibatch_size(self, **unused_kw):
    return 1000


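# A minimal, illustrative sketch (hypothetical, not part of LIT) of a
# ProjectorModel backed by scikit-learn PCA. It assumes scikit-learn is
# installed; the import is kept inside __init__ so this sketch does not add a
# module-level dependency.
class _ExamplePCAProjector(ProjectorModel):
  """Projects 'x' embeddings down to low-dimensional 'z' coordinates."""

  def __init__(self, n_components: int = 2):
    from sklearn.decomposition import PCA  # assumed available
    self._pca = PCA(n_components=n_components)

  def fit_transform(self, inputs: Iterable[JsonDict]) -> list[JsonDict]:
    x = np.stack([np.asarray(ex['x']) for ex in inputs])
    zs = self._pca.fit_transform(x)
    return [{'z': z} for z in zs]

  def predict_minibatch(self, inputs: list[JsonDict]) -> list[JsonDict]:
    x = np.stack([np.asarray(ex['x']) for ex in inputs])
    zs = self._pca.transform(x)
    return [{'z': z} for z in zs]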