diff --git a/lit_nlp/api/model.py b/lit_nlp/api/model.py
index 8aa610ae..6acc4b3c 100644
--- a/lit_nlp/api/model.py
+++ b/lit_nlp/api/model.py
@@ -82,10 +82,6 @@ def description(self) -> str:
     """
     return inspect.getdoc(self) or ''

-  def max_minibatch_size(self) -> int:
-    """Maximum minibatch size for this model."""
-    return 1
-
   @classmethod
   def init_spec(cls) -> Optional[Spec]:
     """Attempts to infer a Spec describing a Model's constructor parameters.
@@ -137,22 +133,10 @@ def supports_concurrent_predictions(self):

     Returns:
       (bool) True if the model can handle multiple concurrent calls to its
-      `predict_minibatch` method.
+      `predict` method.
     """
     return False

-  @abc.abstractmethod
-  def predict_minibatch(self, inputs: list[JsonDict]) -> list[JsonDict]:
-    """Run prediction on a batch of inputs.
-
-    Args:
-      inputs: sequence of inputs, following model.input_spec()
-
-    Returns:
-      list of outputs, following model.output_spec()
-    """
-    return
-
   def load(self, path: str):
     """Load and return a new instance of this model loaded from a new path.

@@ -194,41 +178,10 @@ def get_embedding_table(self) -> tuple[list[str], np.ndarray]:
     raise NotImplementedError('get_embedding_table() not implemented for ' +
                               self.__class__.__name__)

-  ##
-  # Concrete implementations of common functions.
+  @abc.abstractmethod
   def predict(self, inputs: Iterable[JsonDict], **kw) -> Iterable[JsonDict]:
-    """Run prediction on a dataset.
-
-    This uses minibatch inference for efficiency, but yields per-example output.
-
-    This will also copy some NumPy arrays if they look like slices of a larger
-    tensor. This adds some overhead, but reduces memory leaks by allowing the
-    source tensor (which may be a large padded matrix) to be garbage collected.
-
-    Args:
-      inputs: iterable of input dicts
-      **kw: additional kwargs passed to predict_minibatch()
-
-    Returns:
-      model outputs, for each input
-    """
-    results = self._batched_predict(inputs, **kw)
-    results = (scrub_numpy_refs(res) for res in results)
-    return results
-
-  def _batched_predict(self, inputs: Iterable[JsonDict],
-                       **kw) -> Iterator[JsonDict]:
-    """Internal helper to predict using minibatches."""
-    minibatch_size = self.max_minibatch_size(**kw)
-    minibatch = []
-    for ex in inputs:
-      if len(minibatch) < minibatch_size:
-        minibatch.append(ex)
-      if len(minibatch) >= minibatch_size:
-        yield from self.predict_minibatch(minibatch, **kw)
-        minibatch = []
-    if len(minibatch) > 0:  # pylint: disable=g-explicit-length-test
-      yield from self.predict_minibatch(minibatch, **kw)
+    """Run prediction on a list of inputs and return the outputs."""
+    pass


 class ModelWrapper(Model):
@@ -250,16 +203,10 @@ def wrapped(self):
   def description(self) -> str:
     return self.wrapped.description()

-  def max_minibatch_size(self) -> int:
-    return self.wrapped.max_minibatch_size()
-
   @property
   def supports_concurrent_predictions(self):
     return self.wrapped.supports_concurrent_predictions

-  def predict_minibatch(self, inputs: list[JsonDict], **kw) -> list[JsonDict]:
-    return self.wrapped.predict_minibatch(inputs, **kw)
-
   def predict(
       self, inputs: Iterable[JsonDict], *args, **kw
   ) -> Iterable[JsonDict]:
@@ -285,10 +232,64 @@ def get_embedding_table(self) -> tuple[list[str], np.ndarray]:
 class BatchedModel(Model):
   """Generic base class for the batched model.

-  Currently this is a no-op pass-through of Model class and will be updated
-  after moving users of Model class over.
+  Subclass needs to implement predict_minibatch() and optionally
+  max_minibatch_size().
   """
-  pass
+
+  def max_minibatch_size(self) -> int:
+    """Maximum minibatch size for this model."""
+    return 1
+
+  @property
+  def supports_concurrent_predictions(self):
+    return False
+
+  @abc.abstractmethod
+  def predict_minibatch(self, inputs: list[JsonDict]) -> list[JsonDict]:
+    """Run prediction on a batch of inputs.
+
+    Args:
+      inputs: sequence of inputs, following model.input_spec()
+
+    Returns:
+      list of outputs, following model.output_spec()
+    """
+    pass
+
+  def predict(self, inputs: Iterable[JsonDict], **kw) -> Iterable[JsonDict]:
+    """Run prediction on a dataset.
+
+    This uses minibatch inference for efficiency, but yields per-example output.
+
+    This will also copy some NumPy arrays if they look like slices of a larger
+    tensor. This adds some overhead, but reduces memory leaks by allowing the
+    source tensor (which may be a large padded matrix) to be garbage collected.
+
+    Args:
+      inputs: iterable of input dicts
+      **kw: additional kwargs passed to predict_minibatch()
+
+    Returns:
+      model outputs, for each input
+    """
+    results = self.batched_predict(inputs, **kw)
+    results = (scrub_numpy_refs(res) for res in results)
+    return results
+
+  def batched_predict(
+      self, inputs: Iterable[JsonDict], **kw
+  ) -> Iterator[JsonDict]:
+    """Internal helper to predict using minibatches."""
+    minibatch_size = self.max_minibatch_size(**kw)
+    minibatch = []
+    for ex in inputs:
+      if len(minibatch) < minibatch_size:
+        minibatch.append(ex)
+      if len(minibatch) >= minibatch_size:
+        yield from self.predict_minibatch(minibatch, **kw)
+        minibatch = []
+    if len(minibatch) > 0:  # pylint: disable=g-explicit-length-test
+      yield from self.predict_minibatch(minibatch, **kw)


 class BatchedRemoteModel(Model):
diff --git a/lit_nlp/api/model_test.py b/lit_nlp/api/model_test.py
index 71cd7e13..dd831220 100644
--- a/lit_nlp/api/model_test.py
+++ b/lit_nlp/api/model_test.py
@@ -33,8 +33,7 @@ def input_spec(self) -> types.Spec:
   def output_spec(self) -> types.Spec:
     return {}

-  def predict_minibatch(self,
-                        inputs: list[model.JsonDict]) -> list[model.JsonDict]:
+  def predict(self, inputs: list[model.JsonDict]) -> list[model.JsonDict]:
     return []


@@ -77,7 +76,7 @@ def input_spec(self) -> types.Spec:
   def output_spec(self) -> types.Spec:
     return {}

-  def predict_minibatch(self, *args, **kwargs) -> list[types.JsonDict]:
+  def predict(self, *args, **kwargs) -> list[types.JsonDict]:
     return []


diff --git a/lit_nlp/components/hotflip_test.py b/lit_nlp/components/hotflip_test.py
index 5141f5af..da7fcb4a 100644
--- a/lit_nlp/components/hotflip_test.py
+++ b/lit_nlp/components/hotflip_test.py
@@ -58,8 +58,9 @@ def output_spec(self) -> dict[str, lit_types.LitType]:
   def get_embedding_table(self):
     return ([], np.ndarray([]))

-  def predict_minibatch(
-      self, inputs: list[lit_model.JsonDict]) -> list[lit_model.JsonDict]:
+  def predict(
+      self, inputs: list[lit_model.JsonDict]
+  ) -> list[lit_model.JsonDict]:
     pass


@@ -108,8 +109,9 @@ def output_spec(self) -> dict[str, lit_types.LitType]:
   def get_embedding_table(self):
     return ([], np.ndarray([]))

-  def predict_minibatch(
-      self, inputs: list[lit_model.JsonDict]) -> list[lit_model.JsonDict]:
+  def predict(
+      self, inputs: list[lit_model.JsonDict]
+  ) -> list[lit_model.JsonDict]:
     pass


diff --git a/lit_nlp/components/shap_explainer_test.py b/lit_nlp/components/shap_explainer_test.py
index 4640d14f..f8fe0447 100644
--- a/lit_nlp/components/shap_explainer_test.py
+++ b/lit_nlp/components/shap_explainer_test.py
@@ -48,7 +48,7 @@ def input_spec(self) -> lit_types.Spec:
   def output_spec(self) -> lit_types.Spec:
     return {}

-  def predict_minibatch(self, inputs, **kw):
+  def predict(self, inputs, **kw):
     return None


diff --git a/lit_nlp/examples/models/t5.py b/lit_nlp/examples/models/t5.py
index 9970429b..94c075b8 100644
--- a/lit_nlp/examples/models/t5.py
+++ b/lit_nlp/examples/models/t5.py
@@ -423,13 +423,6 @@ def preprocess(self, ex: JsonDict) -> JsonDict:
   def description(self) -> str:
     return "T5 for machine translation\n" + self.wrapped.description()

-  # TODO(b/170662608): remove these after batching API is cleaned up.
-  def max_minibatch_size(self) -> int:
-    raise NotImplementedError("Use predict() instead.")
-
-  def predict_minibatch(self, inputs):
-    raise NotImplementedError("Use predict() instead.")
-
   def predict(self, inputs):
     """Predict on a single minibatch of examples."""
     model_inputs = (self.preprocess(ex) for ex in inputs)
@@ -479,13 +472,6 @@ def preprocess(self, ex: JsonDict) -> JsonDict:
   def description(self) -> str:
     return "T5 for summarization\n" + self.wrapped.description()

-  # TODO(b/170662608): remove these after batching API is cleaned up.
-  def max_minibatch_size(self) -> int:
-    raise NotImplementedError("Use predict() instead.")
-
-  def predict_minibatch(self, inputs):
-    raise NotImplementedError("Use predict() instead.")
-
   def predict(self, inputs):
     """Predict on a single minibatch of examples."""
     inputs = list(inputs)  # needs to be referenced below, so keep full list
diff --git a/lit_nlp/lib/caching.py b/lit_nlp/lib/caching.py
index 88daac97..2dccdc24 100644
--- a/lit_nlp/lib/caching.py
+++ b/lit_nlp/lib/caching.py
@@ -278,12 +278,6 @@ def fit_transform(self, inputs: Iterable[JsonDict]):
       self._cache.put(output, cache_key)
     return outputs

-  # TODO(b/170662608) Remove once batching logic changes are done.
-  def predict_minibatch(self, *args, **kw):
-    raise RuntimeError(
-        "This method should be inaccessible as it bypasses the cache. Please"
-        " use CachingModelWrapper.predict().")
-
   def predict(self,
               inputs: Iterable[JsonDict],
               progress_indicator: Optional[ProgressIndicator] = lambda x: x,
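Illustrative usage (not part of the diff above): a minimal sketch of what a model subclass might look like under the refactored API, where `BatchedModel` owns the minibatching logic and plain `Model` subclasses implement `predict()` directly. `ToyLengthModel` and its specs are hypothetical names invented for this sketch; only the `lit_nlp.api.model` / `lit_nlp.api.types` imports and the methods shown in the diff are assumed.

```python
# Hypothetical example (not from the diff): subclassing the new BatchedModel.
from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types


class ToyLengthModel(lit_model.BatchedModel):
  """Toy model: returns the character length of each input text."""

  def max_minibatch_size(self) -> int:
    return 8  # Override the base-class default of 1 to get real batching.

  def input_spec(self) -> lit_types.Spec:
    return {"text": lit_types.TextSegment()}

  def output_spec(self) -> lit_types.Spec:
    return {"length": lit_types.Scalar()}

  def predict_minibatch(self, inputs):
    # One output dict per input dict, following output_spec().
    return [{"length": len(ex["text"])} for ex in inputs]


# Callers keep using the Model.predict() interface; batching happens inside.
outputs = list(ToyLengthModel().predict([{"text": "hello"}, {"text": "hi"}]))
```

Models that handle their own batching, like the T5 wrappers in this diff, instead subclass `Model` and implement `predict()` themselves.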