Skip to content

Commit

Permalink
Make existing LIT model classes inherit from LIT's BatchedModel ins…
Browse files Browse the repository at this point in the history
…tead of `Model`.

The goal is to move the batching logic out of LIT's `Model` to `BatchedModel`.
Currently LIT `BatchedModel` just inherits from `Model` with [no additional functionality](http://google3/third_party/py/lit_nlp/api/model.py;l=285-291;rcl=552893648). We plan to
* migrate all use cases that require the batching logic to `BatchedModel` (current change),
* move the relevant batching functionalities from `Model` to `BatchedModel` (future changes in LIT internals that don't require user involvement).

PiperOrigin-RevId: 553497520
  • Loading branch information
bdu91 authored and LIT team committed Aug 3, 2023
1 parent 774cdbc commit 0146d5f
Show file tree
Hide file tree
Showing 22 changed files with 33 additions and 33 deletions.
2 changes: 1 addition & 1 deletion lit_nlp/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def predict_minibatch(self, inputs: list[JsonDict]) -> list[JsonDict]:
return


class ProjectorModel(Model, metaclass=abc.ABCMeta):
class ProjectorModel(BatchedModel, metaclass=abc.ABCMeta):
"""LIT Model API for dimensionality reduction."""

##
Expand Down
2 changes: 1 addition & 1 deletion lit_nlp/api/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def predict_minibatch(self,
return []


class _BatchingTestModel(model.Model):
class _BatchingTestModel(model.BatchedModel):
"""A model for testing batched predictions with a minibatch size of 3."""

def __init__(self):
Expand Down
2 changes: 1 addition & 1 deletion lit_nlp/components/curves_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
COLORS = ['red', 'green', 'blue']

_Curve = list[tuple[float, float]]
_Model = lit_model.Model
_Model = lit_model.BatchedModel


class _DataEntryForTesting(NamedTuple):
Expand Down
4 changes: 2 additions & 2 deletions lit_nlp/components/image_gradient_maps_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
JsonDict = lit_types.JsonDict


class ClassificationTestModel(lit_model.Model):
class ClassificationTestModel(lit_model.BatchedModel):

LABELS = ['Dummy', 'Cat', 'Dog']
GRADIENT_SHAPE = (60, 40, 3)
Expand Down Expand Up @@ -62,7 +62,7 @@ def output_spec(self):
}


class RegressionTestModel(lit_model.Model):
class RegressionTestModel(lit_model.BatchedModel):
"""A test model for testing the regression case."""

GRADIENT_SHAPE = (40, 20, 3)
Expand Down
4 changes: 2 additions & 2 deletions lit_nlp/components/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
LitType = types.LitType


class _GenTextTestModel(lit_model.Model):
class _GenTextTestModel(lit_model.BatchedModel):

def input_spec(self) -> types.Spec:
return {'input': types.TextSegment()}
Expand All @@ -40,7 +40,7 @@ def predict_minibatch(self,
return [{'output': 'test_output'}] * len(inputs)


class _GenTextCandidatesTestModel(lit_model.Model):
class _GenTextCandidatesTestModel(lit_model.BatchedModel):

def input_spec(self) -> types.Spec:
return {
Expand Down
4 changes: 2 additions & 2 deletions lit_nlp/components/minimal_targeted_counterfactuals_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def examples(self) -> List[lit_types.JsonDict]:
]


class ClassificationTestModel(lit_model.Model):
class ClassificationTestModel(lit_model.BatchedModel):
"""A test model for testing tabular hot-flips on classification tasks."""

def __init__(self, dataset: lit_dataset.Dataset) -> None:
Expand Down Expand Up @@ -168,7 +168,7 @@ def examples(self) -> List[lit_types.JsonDict]:
]


class RegressionTestModel(lit_model.Model):
class RegressionTestModel(lit_model.BatchedModel):
"""A test model for testing tabular hot-flips on regression tasks."""

def max_minibatch_size(self, **unused) -> int:
Expand Down
2 changes: 1 addition & 1 deletion lit_nlp/components/nearest_neighbors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
JsonDict = lit_types.JsonDict


class TestModelNearestNeighbors(lit_model.Model):
class TestModelNearestNeighbors(lit_model.BatchedModel):
"""Implements lit.Model interface for nearest neighbors.
Returns the same output for every input.
Expand Down
4 changes: 2 additions & 2 deletions lit_nlp/components/pdp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
JsonDict = lit_types.JsonDict


class TestRegressionPdp(lit_model.Model):
class TestRegressionPdp(lit_model.BatchedModel):

def input_spec(self):
return {'num': lit_types.Scalar(),
Expand All @@ -42,7 +42,7 @@ def predict_minibatch(self, inputs: List[JsonDict], **kw):
for i in inputs]


class TestClassificationPdp(lit_model.Model):
class TestClassificationPdp(lit_model.BatchedModel):

def input_spec(self):
return {'num': lit_types.Scalar(),
Expand Down
2 changes: 1 addition & 1 deletion lit_nlp/components/remote_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def query_lit_server(url: Text,
return serialize.from_json(six.ensure_text(response_bytes))


class RemoteModel(lit_model.Model):
class RemoteModel(lit_model.BatchedModel):
"""LIT model backed by a remote LIT server."""

def __init__(self, url: Text, name: Text, max_minibatch_size: int = 256):
Expand Down
2 changes: 1 addition & 1 deletion lit_nlp/components/static_preds.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
JsonDict = lit_types.JsonDict


class StaticPredictions(lit_model.Model):
class StaticPredictions(lit_model.BatchedModel):
"""Implements lit.Model interface for a set of pre-computed predictions."""

def key_fn(self, example: JsonDict) -> str:
Expand Down
2 changes: 1 addition & 1 deletion lit_nlp/components/tcav_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
_TEST_VOCAB = ['0', '1']


class VariableOutputSpecModel(lit_model.Model):
class VariableOutputSpecModel(lit_model.BatchedModel):
"""A dummy model used for testing interpreter compatibility."""

def __init__(self, output_spec: lit_types.Spec):
Expand Down
2 changes: 1 addition & 1 deletion lit_nlp/components/tfx_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _inputs_to_serialized_example(input_dict: lit_types.JsonDict):
return result.SerializeToString()


class TFXModel(lit_model.Model):
class TFXModel(lit_model.BatchedModel):
"""Wrapper for querying a TFX-generated SavedModel."""

def __init__(self, config: TFXModelConfig):
Expand Down
4 changes: 2 additions & 2 deletions lit_nlp/examples/coref/edge_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(self, examples):
@classmethod
def build(cls,
inputs: List[JsonDict],
encoder: lit_model.Model,
encoder: lit_model.BatchedModel,
edge_field: str,
embs_field: str,
offset_field: str,
Expand Down Expand Up @@ -140,7 +140,7 @@ def spec(self):
}


class SingleEdgePredictor(lit_model.Model):
class SingleEdgePredictor(lit_model.BatchedModel):
"""Coref model for a single edge. Compatible with EdgeFeaturesDataset."""

def build_model(self, input_dim: int, hidden_dim: int = 256):
Expand Down
2 changes: 1 addition & 1 deletion lit_nlp/examples/coref/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import transformers


class BertEncoderWithOffsets(lit_model.Model):
class BertEncoderWithOffsets(lit_model.BatchedModel):
"""BERT encoder for pre-tokenized text."""

@property
Expand Down
2 changes: 1 addition & 1 deletion lit_nlp/examples/coref/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
JsonDict = lit_types.JsonDict


class FrozenEncoderCoref(lit_model.Model):
class FrozenEncoderCoref(lit_model.BatchedModel):
"""Frozen-encoder coreference model."""

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion lit_nlp/examples/models/glue_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def init_spec(cls) -> lit_types.Spec:
}


class GlueModel(lit_model.Model):
class GlueModel(lit_model.BatchedModel):
"""GLUE benchmark model, using Keras/TF2 and Huggingface Transformers.
This is a general-purpose classification or regression model. It works for
Expand Down
2 changes: 1 addition & 1 deletion lit_nlp/examples/models/mobilenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
IMAGE_SHAPE = (224, 224, 3)


class MobileNet(model.Model):
class MobileNet(model.BatchedModel):
"""MobileNet model trained on ImageNet dataset."""

def __init__(self, name='mobilenet_v2') -> None:
Expand Down
2 changes: 1 addition & 1 deletion lit_nlp/examples/models/penguin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
_VOCABS = penguin_data.VOCABS


class PenguinModel(lit_model.Model):
class PenguinModel(lit_model.BatchedModel):
"""TensorFlow Keras model for penguin classification."""

def __init__(self, path: str):
Expand Down
4 changes: 2 additions & 2 deletions lit_nlp/examples/models/pretrained_lms.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import transformers


class BertMLM(lit_model.Model):
class BertMLM(lit_model.BatchedModel):
"""BERT masked LM using Huggingface Transformers and TensorFlow 2."""

MASK_TOKEN = "[MASK]"
Expand Down Expand Up @@ -137,7 +137,7 @@ def output_spec(self):
}


class GPT2LanguageModel(lit_model.Model):
class GPT2LanguageModel(lit_model.BatchedModel):
"""Wrapper for a Huggingface Transformers GPT-2 model.
This class loads a tokenizer and model using the Huggingface library and
Expand Down
4 changes: 2 additions & 2 deletions lit_nlp/examples/models/t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def validate_t5_model(model: lit_model.Model) -> lit_model.Model:
return model


class T5SavedModel(lit_model.Model):
class T5SavedModel(lit_model.BatchedModel):
"""T5 from a TensorFlow SavedModel, for black-box access.
To create a SavedModel from a regular T5 checkpoint, see
Expand Down Expand Up @@ -150,7 +150,7 @@ def output_spec(self):
return {"output_text": lit_types.GeneratedText(parent="target_text")}


class T5HFModel(lit_model.Model):
class T5HFModel(lit_model.BatchedModel):
"""T5 using HuggingFace Transformers and Keras.
This version supports embeddings, attention, and force-decoding of the target
Expand Down
4 changes: 2 additions & 2 deletions lit_nlp/examples/simple_tf2_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _from_pretrained(cls, *args, **kw):
return cls.from_pretrained(*args, from_pt=True, **kw)


class SimpleSentimentModel(lit_model.Model):
class SimpleSentimentModel(lit_model.BatchedModel):
"""Simple sentiment analysis model."""

LABELS = ["0", "1"] # negative, positive
Expand All @@ -95,7 +95,7 @@ def __init__(self, model_name_or_path):
##
# LIT API implementation
def max_minibatch_size(self):
# This tells lit_model.Model.predict() how to batch inputs to
# This tells lit_model.BatchedModel.predict() how to batch inputs to
# predict_minibatch().
# Alternately, you can just override predict() and handle batching yourself.
return 32
Expand Down
8 changes: 4 additions & 4 deletions lit_nlp/lib/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
JsonDict = lit_types.JsonDict


class RegressionModelForTesting(lit_model.Model):
class RegressionModelForTesting(lit_model.BatchedModel):
"""Implements lit.Model interface for testing.
This class allows flexible input spec to allow different testing scenarios.
Expand Down Expand Up @@ -67,7 +67,7 @@ def predict(self, inputs: Iterable[JsonDict], *args,
return map(lambda x: {'scores': 0.0}, inputs)


class IdentityRegressionModelForTesting(lit_model.Model):
class IdentityRegressionModelForTesting(lit_model.BatchedModel):
"""Implements lit.Model interface for testing.
This class reflects the input in the prediction for simple testing.
Expand Down Expand Up @@ -107,7 +107,7 @@ def count(self):
return self._count


class ClassificationModelForTesting(lit_model.Model):
class ClassificationModelForTesting(lit_model.BatchedModel):
"""Implements lit.Model interface for testing classification models.
Returns the same output for every input.
Expand Down Expand Up @@ -177,7 +177,7 @@ def assert_deep_almost_equal(testcase, result, actual, places=4):
assert_deep_almost_equal(testcase, result[key], actual[key])


class CustomOutputModelForTesting(lit_model.Model):
class CustomOutputModelForTesting(lit_model.BatchedModel):
"""Implements lit.Model interface for testing.
This class allows user-specified outputs for testing return values.
Expand Down

0 comments on commit 0146d5f

Please sign in to comment.