Skip to content

Commit

Permalink
Adds init_spec() to lit_nlp.api.model classes.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 499350310
  • Loading branch information
RyanMullins authored and LIT team committed Jan 4, 2023
1 parent d624562 commit 7bb60b2
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 1 deletion.
28 changes: 27 additions & 1 deletion lit_nlp/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
import inspect
import itertools
import multiprocessing # for ThreadPool
from typing import Iterable, Iterator, Union
from typing import Iterable, Iterator, Optional, Union

from absl import logging
import attr
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import types
Expand Down Expand Up @@ -92,6 +93,31 @@ def max_minibatch_size(self) -> int:
"""Maximum minibatch size for this model."""
return 1

def init_spec(self) -> Optional[Spec]:
  """Attempts to infer a Spec describing a Model's constructor parameters.

  The Model base class attempts to infer a Spec for the constructor using
  `lit_nlp.api.types.infer_spec_for_func()`.

  If successful, this function will return a `dict[str, LitType]`. If
  unsuccessful (i.e., the inferencer raises a `TypeError` because it
  encounters a parameter that is not supported by `infer_spec_for_func()`),
  this function will return None, log a warning describing where and how the
  inferencing failed, and LIT users **will not** be able to load new
  instances of this model from the UI.

  Returns:
    A Spec representation of the Model's constructor, or None if a Spec
    could not be inferred.
  """
  try:
    spec = types.infer_spec_for_func(self.__init__)
  except TypeError as e:
    # Fall through to returning None; exc_info=True preserves the traceback
    # in the log so the unsupported parameter can be identified.
    spec = None
    logging.warning("Unable to infer init spec for model '%s'. %s",
                    self.__class__.__name__, str(e), exc_info=True)
  return spec

def is_compatible_with_dataset(self, dataset: lit_dataset.Dataset) -> bool:
"""Return true if this model is compatible with the dataset spec."""
dataset_spec = dataset.spec()
Expand Down
38 changes: 38 additions & 0 deletions lit_nlp/api/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,22 @@ def predict_minibatch(self, inputs: list[model.JsonDict], **kw):
return map(lambda x: {"scores": x["value"]}, inputs)


class TestSavedModel(model.Model):
  """Dummy model with saved-model-like constructor args, for init_spec() tests."""

  def __init__(self, path: str, *args, compute_embs: bool = False, **kwargs):
    # Intentionally a no-op: only the constructor signature matters here,
    # since init_spec() is inferred from it.
    pass

  def input_spec(self) -> types.Spec:
    # No inputs needed for these tests.
    return {}

  def output_spec(self) -> types.Spec:
    # No outputs needed for these tests.
    return {}

  def predict_minibatch(self, *args, **kwargs) -> list[types.JsonDict]:
    # Never called in these tests; return an empty prediction list.
    return []


class ModelTest(parameterized.TestCase):

@parameterized.named_parameters(
Expand Down Expand Up @@ -181,6 +197,28 @@ def test_batched_predict(self, inputs: list[model.JsonDict],
self.assertEqual(len(result), len(inputs))
self.assertEqual(test_model.count, expected_run_count)

def test_init_spec_empty(self):
  """A constructor with no parameters should yield an empty (non-None) spec."""
  spec = TestBatchingModel().init_spec()
  self.assertEmpty(spec)

def test_init_spec_populated(self):
  """Constructor params with supported annotations appear in the spec."""
  expected = {
      "path": types.String(),
      "compute_embs": types.Boolean(default=False, required=False),
  }
  self.assertEqual(TestSavedModel("test/path").init_spec(), expected)

@parameterized.named_parameters(
    ("bad_args", CompatibilityTestModel({})),
    # Every ModelWrapper should yield None from init_spec(), no matter
    # which model it wraps.
    ("wrap_bad_args", model.ModelWrapper(CompatibilityTestModel({}))),
    ("wrap_good_args", model.ModelWrapper(TestSavedModel("test/path"))),
    ("wrap_no_args", model.ModelWrapper(TestBatchingModel())),
)
def test_init_spec_none(self, test_model: model.Model):
  """Models whose constructors cannot be inferred should report None."""
  self.assertIsNone(test_model.init_spec())


# Run the test suite when this module is executed directly.
if __name__ == "__main__":
  absltest.main()

0 comments on commit 7bb60b2

Please sign in to comment.