From 7bb60b24aa0ec7755d9e763f8689a3558409e5bc Mon Sep 17 00:00:00 2001
From: Ryan Mullins
Date: Tue, 3 Jan 2023 17:09:30 -0800
Subject: [PATCH] Adds init_spec() to lit_nlp.api.model classes.

PiperOrigin-RevId: 499350310
---
 lit_nlp/api/model.py      | 28 +++++++++++++++++++++++++++-
 lit_nlp/api/model_test.py | 38 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 65 insertions(+), 1 deletion(-)

diff --git a/lit_nlp/api/model.py b/lit_nlp/api/model.py
index c94e7d1b..034457d1 100644
--- a/lit_nlp/api/model.py
+++ b/lit_nlp/api/model.py
@@ -17,8 +17,9 @@
 import inspect
 import itertools
 import multiprocessing  # for ThreadPool
-from typing import Iterable, Iterator, Union
+from typing import Iterable, Iterator, Optional, Union
 
+from absl import logging
 import attr
 from lit_nlp.api import dataset as lit_dataset
 from lit_nlp.api import types
@@ -92,6 +93,31 @@ def max_minibatch_size(self) -> int:
     """Maximum minibatch size for this model."""
     return 1
 
+  def init_spec(self) -> Optional[Spec]:
+    """Attempts to infer a Spec describing a Model's constructor parameters.
+
+    The Model base class attempts to infer a Spec for the constructor using
+    `lit_nlp.api.types.infer_spec_for_func()`.
+
+    If successful, this function will return a `dict[str, LitType]`. If
+    unsuccessful (i.e., `infer_spec_for_func()` raises a `TypeError` because
+    it encounters a parameter it does not support), this function will
+    return None, log a warning describing where and how the inference
+    failed, and LIT users **will not** be able to load new instances of
+    this model from the UI.
+
+    Returns:
+      A Spec representation of the Model's constructor, or None if a Spec
+      could not be inferred.
+    """
+    try:
+      spec = types.infer_spec_for_func(self.__init__)
+    except TypeError as e:
+      spec = None
+      logging.warning("Unable to infer init spec for model '%s'. %s",
+                      self.__class__.__name__, str(e), exc_info=True)
+    return spec
+
   def is_compatible_with_dataset(self, dataset: lit_dataset.Dataset) -> bool:
     """Return true if this model is compatible with the dataset spec."""
     dataset_spec = dataset.spec()
diff --git a/lit_nlp/api/model_test.py b/lit_nlp/api/model_test.py
index 329ace48..054fddf2 100644
--- a/lit_nlp/api/model_test.py
+++ b/lit_nlp/api/model_test.py
@@ -65,6 +65,22 @@ def predict_minibatch(self, inputs: list[model.JsonDict], **kw):
     return map(lambda x: {"scores": x["value"]}, inputs)
 
 
+class TestSavedModel(model.Model):
+  """A dummy model imitating saved-model semantics for testing init_spec()."""
+
+  def __init__(self, path: str, *args, compute_embs: bool = False, **kwargs):
+    pass
+
+  def input_spec(self) -> types.Spec:
+    return {}
+
+  def output_spec(self) -> types.Spec:
+    return {}
+
+  def predict_minibatch(self, *args, **kwargs) -> list[types.JsonDict]:
+    return []
+
+
 class ModelTest(parameterized.TestCase):
 
   @parameterized.named_parameters(
@@ -181,6 +197,28 @@ def test_batched_predict(self, inputs: list[model.JsonDict],
     self.assertEqual(len(result), len(inputs))
     self.assertEqual(test_model.count, expected_run_count)
 
+  def test_init_spec_empty(self):
+    mdl = TestBatchingModel()
+    self.assertEmpty(mdl.init_spec())
+
+  def test_init_spec_populated(self):
+    mdl = TestSavedModel("test/path")
+    self.assertEqual(mdl.init_spec(), {
+        "path": types.String(),
+        "compute_embs": types.Boolean(default=False, required=False),
+    })
+
+  @parameterized.named_parameters(
+      ("bad_args", CompatibilityTestModel({})),
+      # All ModelWrapper instances should return None, regardless of the model
+      # the instance is wrapping.
+      ("wrap_bad_args", model.ModelWrapper(CompatibilityTestModel({}))),
+      ("wrap_good_args", model.ModelWrapper(TestSavedModel("test/path"))),
+      ("wrap_no_args", model.ModelWrapper(TestBatchingModel())),
+  )
+  def test_init_spec_none(self, mdl: model.Model):
+    self.assertIsNone(mdl.init_spec())
+
 
 if __name__ == "__main__":
   absltest.main()