Optional --strict_cache_id_validation mode to verify that example IDs match their contents.

This should help catch bugs related to _id fields not getting reset.

PiperOrigin-RevId: 553271307
iftenney authored and LIT team committed Aug 2, 2023
1 parent 1e5df5f commit 774cdbc
Showing 4 changed files with 56 additions and 8 deletions.
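
To make the failure mode concrete before the diffs, here is a minimal runnable sketch. toy_id_hash is a stand-in for the real id function (the diff wires in caching.input_hash, whose details may differ), and the example dict is hypothetical:

import hashlib

def toy_id_hash(example: dict) -> str:
  # Hash every feature except _id itself, so the id is a pure
  # function of the example's contents.
  features = {k: v for k, v in example.items() if k != '_id'}
  return hashlib.md5(repr(sorted(features.items())).encode()).hexdigest()

ex = {'text': 'good movie'}
ex['_id'] = toy_id_hash(ex)
assert toy_id_hash(ex) == ex['_id']   # id matches contents

ex['text'] = 'bad movie'              # edited, but _id was not reset
assert toy_id_hash(ex) != ex['_id']   # the mismatch strict mode catches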
14 changes: 11 additions & 3 deletions lit_nlp/app.py

@@ -487,7 +487,7 @@ def _create_model(self,

     new_model = model_cls(**config)
     self._models[new_name] = caching.CachingModelWrapper(
-        new_model, new_name, cache_dir=self._data_dir
+        new_model, new_name, **self._caching_model_wrapper_kw
     )
     empty_dataset = lit_dataset.NoneDataset(self._models)
     self._datasets[_EMPTY_DATASET_KEY] = lit_dataset.IndexedDataset(

@@ -875,6 +875,7 @@ def __init__(
       validate: Optional[flag_helpers.ValidationMode] = None,
       report_all: bool = False,
       enforce_dataset_fields_required: bool = False,
+      strict_cache_id_validation: bool = False,
   ):
     if client_root is None:
       raise ValueError('client_root must be set on application')

@@ -890,6 +891,12 @@
     if data_dir and not os.path.isdir(data_dir):
       os.mkdir(data_dir)

+    self._caching_model_wrapper_kw = dict(
+        cache_dir=self._data_dir,
+        strict_id_validation=strict_cache_id_validation,
+        id_hash_fn=caching.input_hash,
+    )
+
     # TODO(lit-dev): override layouts instead of merging, to allow clients
     # to opt-out of the default bundled layouts. This will require updating
     # client code to manually merge when this is the desired behavior.

@@ -903,8 +910,9 @@
       # the original after wrapping it in a CachingModelWrapper.
       self._model_loaders[name] = (type(model), model.init_spec())
       # Wrap model in caching wrapper and add it to the app
-      self._models[name] = caching.CachingModelWrapper(model, name,
-                                                       cache_dir=data_dir)
+      self._models[name] = caching.CachingModelWrapper(
+          model, name, **self._caching_model_wrapper_kw
+      )

     self._annotators: list[lit_components.Annotator] = annotators or []
     self._saved_datapoints = {}
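
The two CachingModelWrapper call sites above now share one kwargs dict, so future wrapper options need to be added in only one place (previously the model-loading path passed only cache_dir). A sketch of the equivalence, assuming model, name, and the constructor arguments are in scope as in LitApp.__init__:

# this call...
caching.CachingModelWrapper(model, name, **self._caching_model_wrapper_kw)
# ...is now equivalent to:
caching.CachingModelWrapper(
    model, name,
    cache_dir=self._data_dir,
    strict_id_validation=strict_cache_id_validation,
    id_hash_fn=caching.input_hash,
)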
39 changes: 34 additions & 5 deletions lit_nlp/lib/caching.py

@@ -193,16 +193,24 @@ def load_from_disk(self):
 class CachingModelWrapper(lit_model.ModelWrapper):
   """Wrapper to add per-example caching to a LIT model."""

-  def __init__(self,
-               model: lit_model.Model,
-               name: str,
-               cache_dir: Optional[str] = None):
+  def __init__(
+      self,
+      model: lit_model.Model,
+      name: str,
+      cache_dir: Optional[str] = None,
+      strict_id_validation: bool = False,
+      id_hash_fn: Optional[lit_dataset.IdFnType] = None,
+  ):
     """Wrap a model to add caching.

     Args:
       model: a LIT model
       name: name, used for logging and data files
       cache_dir: if given, will load/save data to disk
+      strict_id_validation: if true, will re-compute hashes using id_hash_fn and
+        verify that they match the provided IDs. See b/293984290.
+      id_hash_fn: function of example --> string id, used by
+        strict_id_validation mode.
     """
     super().__init__(model)
     self._name = name

@@ -211,6 +219,13 @@ def __init__(self,
         name, model.supports_concurrent_predictions, cache_dir)
     self.load_cache()

+    self._strict_id_validation = strict_id_validation
+    if self._strict_id_validation:
+      assert (
+          id_hash_fn is not None
+      ), "Must provide id_hash_fn to use strict_id_validation mode."
+    self._id_hash_fn = id_hash_fn
+
   def load_cache(self):
     self._cache.load_from_disk()

@@ -224,9 +239,19 @@ def key_fn(self, d) -> CacheKey:
       return None
     return (self._name, d_id)

+  def _validate_ids(self, inputs: Iterable[JsonDict]):
+    for ex in inputs:
+      if not (given_id := ex.get("_id")):
+        continue
+      if (computed_id := self._id_hash_fn(types.Input(ex))) != given_id:
+        raise ValueError(
+            f"Given id '{given_id}' does not match computed id '{computed_id}'"
+            f" for example {str(ex)}."
+        )
+
   ##
   # For internal use
-  def fit_transform(self, inputs: Iterable[types.JsonDict]):
+  def fit_transform(self, inputs: Iterable[JsonDict]):
     """Cache projections from ProjectorModel dimensionality reducers."""
     wrapped = self.wrapped
     if not isinstance(wrapped, lit_model.ProjectorModel):

@@ -252,6 +277,10 @@ def predict(self,
       progress_indicator: Optional[ProgressIndicator] = lambda x: x,
       **kw) -> list[JsonDict]:
     inputs_as_list = list(inputs)
+
+    if self._strict_id_validation:
+      self._validate_ids(inputs_as_list)
+
     # Try to get results from the cache.
     input_keys = [self.key_fn(d) for d in inputs_as_list]
     if self._cache.pred_lock_key(input_keys):
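
Because _validate_ids runs before the cache lookup in predict(), a stale _id raises instead of silently returning the cached prediction for the old contents. A runnable toy simulation of that ordering; the dict cache and pred() string below stand in for LIT's prediction cache and the wrapped model:

import hashlib

def toy_id_hash(ex: dict) -> str:
  body = {k: v for k, v in ex.items() if k != '_id'}
  return hashlib.md5(repr(sorted(body.items())).encode()).hexdigest()

cache: dict[str, str] = {}  # stand-in for the per-example cache

def predict(ex: dict, strict: bool = False) -> str:
  # Validate BEFORE consulting the cache, as in the hunk above.
  if strict and (given := ex.get('_id')) and toy_id_hash(ex) != given:
    raise ValueError(f"Given id '{given}' does not match computed id.")
  return cache.setdefault(ex['_id'], f"pred({ex['text']})")

ex = {'text': 'good movie'}
ex['_id'] = toy_id_hash(ex)
predict(ex, strict=True)   # computed and cached
ex['text'] = 'bad movie'   # _id not reset after the edit
predict(ex)                # without strict mode: silent stale cache hit
predict(ex, strict=True)   # with strict mode: raises ValueError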
4 changes: 4 additions & 0 deletions lit_nlp/server_config.py

@@ -75,6 +75,10 @@
   config.enforce_dataset_fields_required = False
   config.report_all = False

+  # Whether to re-compute example hashes before checking the cache.
+  # See b/293984290.
+  config.strict_cache_id_validation = False
+
   import os
   import pathlib
   config.client_root = os.path.join(
7 changes: 7 additions & 0 deletions lit_nlp/server_flags.py

@@ -121,6 +121,13 @@
         'If true, and validate is true, will report every issue in validation '
         'as opposed to just the first.',
     ),
+    flags.DEFINE_bool(
+        'strict_cache_id_validation',
+        False,
+        'If true, will re-compute hashes of all examples before checking the'
+        ' cache, and raise an error if any do not match the provided _id'
+        ' field. See b/293984290.',
+    ),
     flags.DEFINE_string(
         'client_root',
         os.path.join(
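
For end-to-end use, the new flag only needs a launcher that forwards parsed server flags into the app. This sketch follows the pattern of LIT's bundled demos; MyModel and MyDataset are hypothetical placeholders for real lit_nlp Model and Dataset subclasses:

from absl import app
from lit_nlp import dev_server
from lit_nlp import server_flags

def main(_):
  models = {'my_model': MyModel()}      # hypothetical Model subclass
  datasets = {'my_data': MyDataset()}   # hypothetical Dataset subclass
  # Forwards all server flags, including --strict_cache_id_validation,
  # into the LitApp constructor.
  lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
  return lit_demo.serve()

if __name__ == '__main__':
  app.run(main)

Launched with --strict_cache_id_validation, the first predict() call on an example whose _id no longer matches its contents raises a ValueError instead of serving a stale cached result.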
