Skip to content

Commit

Permalink
Adds dedicated /get_metrics HTTP API and updates UI accordingly.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 530881175
  • Loading branch information
RyanMullins authored and LIT team committed May 10, 2023
1 parent e7777ad commit 6ba1db8
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 87 deletions.
109 changes: 100 additions & 9 deletions lit_nlp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,25 @@ def _get_preds(self,
Returns:
list[JsonDict] containing requested fields of model predictions
Raises:
KeyError: If `data` does not have an 'inputs' property.
TypeError: If one of entries in `requested_types` is not a valid LitType.
ValueError: If the model returns a different number of predictions than
the number of inputs.
"""
inputs = data['inputs']
preds = list(self._models[model].predict_with_metadata(
data['inputs'], dataset_name=dataset_name, **kw))
inputs, dataset_name=dataset_name, **kw))

num_preds = len(preds)
num_inputs = len(inputs)
if num_preds != num_inputs:
raise ValueError(
f'Different number of model predictions ({num_preds}) than inputs'
f' ({num_inputs}).'
)

if not requested_types and not requested_fields:
return preds

Expand Down Expand Up @@ -469,20 +485,94 @@ def _get_interpretations(
model,
self._datasets[dataset_name],
model_outputs=model_outputs,
config=data.get('config'))
config=data.get('config'),
)

def _get_metrics(
self,
data,
model: str,
dataset_name: str,
metrics: Optional[str] = None,
# TODO(b/278586715): Remove this parameter once linked bug is fixed.
do_predict: str = '1', # bool URL param; encoding as "0" / "1" is safer.
**unused_kw,
) -> types.JsonDict:
"""Run the specified Metrics components.
Args:
data: JSON parsed from the HTTP Request body containing the inputs
(required) and config (optional) for parameterizing the Metrics calls.
model: The name of the model loaded in LIT, used to fetch the model
predictions.
dataset_name: The name of the dataset containing the ground truth labels
for the provided inputs.
metrics: An optional comma-separated string of metrics to run, if None it
will run all Metrics loaded in this LitApp instance.
do_predict: If true (default), will fetch the model predictions in this
function using `_get_preds()` and pass them through to each Metrics
component's run function.
**unused_kw: Unused keyword arguments.
Returns:
A dictionary of metrics results where the keys are the name of the
Metrics component and the values are list of dictionaries containing the
prediction key (`pred_key`), the label key (`label_key`), and `metrics`
for that pair of keys as a `Mapping[str, float]`.
Raises:
KeyError: If a model, dataset, or metric with the specified name is not
loaded in the LitApp instance.
ValueError: If there are no inputs.
"""
inputs = data.get('inputs')
if not inputs:
raise ValueError('Metrics cannot be computed without inputs.')

if dataset_name not in self._datasets:
raise KeyError(f'No dataset named {dataset_name} loaded in LIT.')

if model not in self._models:
raise KeyError(f'No model named {model} loaded in LIT.')

if metrics:
metrics_to_run = tuple(m for m in metrics.split(',') if m)
unknown_metrics = [m for m in metrics_to_run if m not in self._metrics]
if unknown_metrics:
raise KeyError(f'Requested unknown metrics "{unknown_metrics}".')
else:
metrics_to_run = tuple(self._metrics.keys())

if utils.coerce_bool(do_predict):
model_outputs = self._get_preds(data, model, dataset_name)
else:
model_outputs = None

dataset = self._datasets[dataset_name]
model = self._models[model]
config = data.get('config')

return {
name: self._metrics[name].run_with_metadata(
inputs, model, dataset, model_outputs, config
)
for name in metrics_to_run
}

def _push_ui_state(self, data, dataset_name: str, **unused_kw):
"""Push UI state back to Python."""
if self.ui_state_tracker is None:
raise RuntimeError('Attempted to push UI state, but that is not enabled '
'for this server.')
raise RuntimeError(
'Attempted to push UI state, but that is not enabled for this server.'
)
options = data.get('config', {})
self.ui_state_tracker.update_state(data['inputs'],
self._datasets[dataset_name],
dataset_name, **options)
self.ui_state_tracker.update_state(
data['inputs'], self._datasets[dataset_name], dataset_name, **options
)

def _validate(self, validate: Optional[flag_helpers.ValidationMode],
report_all: bool):
def _validate(
self, validate: Optional[flag_helpers.ValidationMode], report_all: bool
):
"""Validate all datasets and models loaded for proper setup."""
if validate is None or validate == flag_helpers.ValidationMode.OFF:
return
Expand Down Expand Up @@ -753,6 +843,7 @@ def __init__(
# Model prediction endpoints.
'/get_preds': self._get_preds,
'/get_interpretations': self._get_interpretations,
'/get_metrics': self._get_metrics,
}
wrapped_handlers = {k: self.make_handler(v) for k, v in handlers.items()}
wrapped_handlers['/load_and_go'] = self._load_and_go
Expand Down
17 changes: 17 additions & 0 deletions lit_nlp/client/lib/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,22 @@ export interface DatasetInfo {
description?: string;
}

export type SerializedDatasetInfo = {
// tslint:disable-next-line:no-any
[K in keyof DatasetInfo]: any;
};

export interface ComponentInfoMap {
[name: string]: ComponentInfo;
}

export interface SerializedComponentInfoMap {
[name: string]: {
// tslint:disable-next-line:no-any
[K in keyof ComponentInfo]: any;
};
}

export interface DatasetInfoMap {
[name: string]: DatasetInfo;
}
Expand All @@ -83,6 +95,11 @@ export interface ModelInfo {
description?: string;
}

export type SerializedModelInfo = {
// tslint:disable-next-line:no-any
[K in keyof ModelInfo]: any;
};

export interface ModelInfoMap {
[modelName: string]: ModelInfo;
}
Expand Down
57 changes: 27 additions & 30 deletions lit_nlp/client/lib/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import {unsafeHTML} from 'lit/directives/unsafe-html.js';

import {marked} from 'marked';
import {LitName, LitType, LitTypeTypesList, LitTypeWithParent, MulticlassPreds, LIT_TYPES_REGISTRY} from './lit_types';
import {CallConfig, FacetMap, LitMetadata, ModelInfoMap, SerializedLitMetadata, SerializedSpec, Spec} from './types';
import {CallConfig, FacetMap, LitMetadata, ModelInfoMap, SerializedComponentInfoMap, SerializedDatasetInfo, SerializedLitMetadata, SerializedModelInfo, SerializedSpec, Spec} from './types';

/** Calculates the mean for a list of numbers */
export function mean(values: number[]): number {
Expand Down Expand Up @@ -170,6 +170,13 @@ export function cloneSpec(spec: Spec): Spec {
return newSpec;
}

function deserializeComponentInfoMap(infoMap: SerializedComponentInfoMap) {
for (const info of Object.values(infoMap)) {
info.configSpec = deserializeLitTypesInSpec(info.configSpec);
info.metaSpec = deserializeLitTypesInSpec(info.metaSpec);
}
}

/**
* Converts serialized LitTypes within the LitMetadata into LitType instances.
*/
Expand All @@ -178,42 +185,32 @@ export function cloneSpec(spec: Spec): Spec {
export function deserializeLitTypesInLitMetadata(
metadata: SerializedLitMetadata): LitMetadata {

for (const model of Object.keys(metadata.models)) {
metadata.models[model].spec.input =
deserializeLitTypesInSpec(metadata.models[model].spec.input);
metadata.models[model].spec.output =
deserializeLitTypesInSpec(metadata.models[model].spec.output);
for (const info of Object.values(metadata.models)) {
const {spec} = info as SerializedModelInfo;
spec.input = deserializeLitTypesInSpec(spec.input as SerializedSpec);
spec.output = deserializeLitTypesInSpec(spec.output as SerializedSpec);
}

for (const dataset of Object.keys(metadata.datasets)) {
metadata.datasets[dataset].spec =
deserializeLitTypesInSpec(metadata.datasets[dataset].spec);
for (const info of Object.values(metadata.datasets)) {
const typedInfo = info as SerializedDatasetInfo;
typedInfo.spec =
deserializeLitTypesInSpec(typedInfo.spec as SerializedSpec);
}

for (const generator of Object.keys(metadata.generators)) {
metadata.generators[generator].configSpec =
deserializeLitTypesInSpec(metadata.generators[generator].configSpec);
metadata.generators[generator].metaSpec =
deserializeLitTypesInSpec(metadata.generators[generator].metaSpec);
}

for (const interpreter of Object.keys(metadata.interpreters)) {
metadata.interpreters[interpreter].configSpec = deserializeLitTypesInSpec(
metadata.interpreters[interpreter].configSpec);
metadata.interpreters[interpreter].metaSpec =
deserializeLitTypesInSpec(metadata.interpreters[interpreter].metaSpec);
}
deserializeComponentInfoMap(metadata.generators);
deserializeComponentInfoMap(metadata.interpreters);
deserializeComponentInfoMap(metadata.metrics);

for (const dataset of Object.keys(metadata.initSpecs.datasets)) {
if (metadata.initSpecs.datasets[dataset] == null) continue;
metadata.initSpecs.datasets[dataset] =
deserializeLitTypesInSpec(metadata.initSpecs.datasets[dataset]);
for (const [name, spec] of Object.entries(metadata.initSpecs.datasets)) {
if (spec == null) continue;
metadata.initSpecs.datasets[name] =
deserializeLitTypesInSpec(spec as SerializedSpec);
}

for (const model of Object.keys(metadata.initSpecs.models)) {
if (metadata.initSpecs.models[model] == null) continue;
metadata.initSpecs.models[model] =
deserializeLitTypesInSpec(metadata.initSpecs.models[model]);
for (const [name, spec] of Object.entries(metadata.initSpecs.models)) {
if (spec == null) continue;
metadata.initSpecs.models[name] =
deserializeLitTypesInSpec(spec as SerializedSpec);
}

return metadata;
Expand Down
Loading

0 comments on commit 6ba1db8

Please sign in to comment.