Skip to content

Commit

Permalink
Remove all remaining uses of run_with_metadata() throughout LIT.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 552904163
  • Loading branch information
nadah09 authored and LIT team committed Aug 1, 2023
1 parent 5ed93bd commit 5047bdd
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 26 deletions.
20 changes: 0 additions & 20 deletions lit_nlp/api/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,6 @@ def run(self,
'Subclass should implement this, or override run_with_metadata() directly.'
)

def run_with_metadata(self,
indexed_inputs: Sequence[IndexedInput],
model: lit_model.Model,
dataset: lit_dataset.IndexedDataset,
model_outputs: Optional[list[JsonDict]] = None,
config: Optional[JsonDict] = None):
"""Run this component, with access to data indices and metadata."""
inputs = [ex['data'] for ex in indexed_inputs]
return self.run(inputs, model, dataset, model_outputs, config)

def is_compatible(self, model: lit_model.Model,
dataset: lit_dataset.Dataset) -> bool:
"""Return if interpreter is compatible with the dataset and model."""
Expand Down Expand Up @@ -166,16 +156,6 @@ def run(
raise NotImplementedError(
'Subclass should implement its own run using compute.')

def run_with_metadata(
self,
indexed_inputs: Sequence[IndexedInput],
model: lit_model.Model,
dataset: lit_dataset.IndexedDataset,
model_outputs: Optional[list[JsonDict]] = None,
config: Optional[JsonDict] = None) -> list[JsonDict]:
inputs = [inp['data'] for inp in indexed_inputs]
return self.run(inputs, model, dataset, model_outputs, config)

# New methods introduced by this subclass

def is_field_compatible(
Expand Down
19 changes: 13 additions & 6 deletions lit_nlp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,8 +523,11 @@ def _get_generated(

dataset = self._datasets[dataset_name]
# Nested list, containing generated examples from each input.
all_generated: list[list[Input]] = genny.run_with_metadata( # pytype: disable=annotation-type-mismatch # always-use-return-annotations
data['inputs'], self._models[model], dataset, config=config)
all_generated: list[list[Input]] = genny.run( # pytype: disable=annotation-type-mismatch # always-use-return-annotations
[ex['data'] for ex in data['inputs']],
self._models[model],
dataset,
config=config)

# Annotate datapoints
def annotate_generated(datapoints):
Expand Down Expand Up @@ -590,8 +593,8 @@ def _get_interpretations(
else:
model_outputs = None

return interp.run_with_metadata(
data['inputs'],
return interp.run(
[ex['data'] for ex in data['inputs']],
mdl,
self._datasets[dataset_name],
model_outputs=model_outputs,
Expand Down Expand Up @@ -673,8 +676,12 @@ def _get_metrics(
config, config_spec, f'Metric {name}'
)

results[name] = metric.run_with_metadata(
inputs, mdl, dataset, model_outputs=model_outputs, config=config
results[name] = metric.run(
[ex['data'] for ex in inputs],
mdl,
dataset,
model_outputs=model_outputs,
config=config
)

return results
Expand Down

0 comments on commit 5047bdd

Please sign in to comment.