Skip to content

Commit

Permalink
Adds dataset and model init spec params to LitApp(...)
Browse files Browse the repository at this point in the history
Refactors prior change to move all init spec info to the top level of LitMetadata. This change was made because the addition of the LitApp(...) params makes it impossible to guarantee the prior assumption that there would be a 1:1 correlation between named datasets/models and an init spec. Also, this has the added benefit of making init specs much easier to find and debug in the LitMetadata JSON structure :)

Demos will be updated to make use of these params in a subsequent revision.

PiperOrigin-RevId: 506286808
  • Loading branch information
RyanMullins authored and LIT team committed Feb 1, 2023
1 parent db51d9d commit f3b0d6e
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 41 deletions.
41 changes: 27 additions & 14 deletions lit_nlp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@

ProgressIndicator = Callable[[Iterable], Iterable]

DatasetLoader = tuple[Callable[..., lit_dataset.Dataset], Optional[types.Spec]]
DatasetLoadersMap = dict[str, DatasetLoader]

ModelLoader = tuple[Callable[..., lit_model.Model], Optional[types.Spec]]
ModelLoadersMap = dict[str, ModelLoader]


class LitApp(object):
"""LIT WSGI application."""
Expand All @@ -60,7 +66,6 @@ def _build_metadata(self):
info = {
'description': model.description(),
'spec': {
'initSpec': self._model_init_specs[name],
'input': model.input_spec(),
'output': model.output_spec(),
}
Expand Down Expand Up @@ -97,7 +102,6 @@ def _build_metadata(self):
dataset_info = {}
for name, ds in self._datasets.items():
dataset_info[name] = {
'init': self._dataset_init_specs[name],
'spec': ds.spec(),
'description': ds.description(),
'size': len(ds),
Expand All @@ -119,6 +123,11 @@ def _build_metadata(self):
'description': interpreter.description()
}

init_specs = {
'datasets': {n: s for n, (_, s) in self._dataset_loaders.items()},
'models': {n: s for n, (_, s) in self._model_loaders.items()},
}

return {
# Component info and specs
'models': model_info,
Expand All @@ -135,6 +144,7 @@ def _build_metadata(self):
'onboardStartDoc': self._onboard_start_doc,
'onboardEndDoc': self._onboard_end_doc,
'syncState': self.ui_state_tracker is not None,
'initSpecs': init_specs,
}

def _get_model_spec(self, name: str):
Expand Down Expand Up @@ -513,11 +523,14 @@ def __init__(
interpreters: Optional[Mapping[str, lit_components.Interpreter]] = None,
annotators: Optional[list[lit_components.Annotator]] = None,
layouts: Optional[layout.LitComponentLayouts] = None,
dataset_loaders: Optional[DatasetLoadersMap] = None,
model_loaders: Optional[ModelLoadersMap] = None,
# General server config; see server_flags.py.
data_dir: Optional[str] = None,
warm_start: float = 0.0,
warm_start_progress_indicator: Optional[ProgressIndicator] = tqdm
.tqdm, # not in server_flags
warm_start_progress_indicator: Optional[
ProgressIndicator
] = tqdm.tqdm, # not in server_flags
warm_projections: bool = False,
client_root: Optional[str] = None,
demo_mode: bool = False,
Expand Down Expand Up @@ -551,13 +564,13 @@ def __init__(
# client code to manually merge when this is the desired behavior.
self._layouts = dict(layout.DEFAULT_LAYOUTS, **(layouts or {}))

self._model_init_specs: dict[str, Optional[types.Spec]] = {}
self._model_loaders: ModelLoadersMap = model_loaders or {}
self._models: dict[str, caching.CachingModelWrapper] = {}
for name, model in models.items():
# We need to extract and store the results of the original
# model.init_spec() here so that we don't lose access to those fields
# after LIT wraps the model in a CachingModelWrapper.
self._model_init_specs[name] = model.init_spec()
if model_loaders is None:
# Attempt to infer an init spec for the model before we lose access to
# the original after wrapping it in a CachingModelWrapper.
self._model_loaders[name] = (type(model), model.init_spec())
# Wrap model in caching wrapper and add it to the app
self._models[name] = caching.CachingModelWrapper(model, name,
cache_dir=data_dir)
Expand All @@ -571,13 +584,13 @@ def __init__(
# dataset on the frontend.
tmp_datasets['_union_empty'] = lit_dataset.NoneDataset(self._models)

self._dataset_init_specs: dict[str, Optional[types.Spec]] = {}
self._dataset_loaders: DatasetLoadersMap = dataset_loaders or {}
self._datasets: dict[str, lit_dataset.IndexedDataset] = {}
for name, ds in tmp_datasets.items():
# We need to extract and store the results of the original
# dataset.init_spec() here so that we don't lose access to those fields
# after LIT goes through the dataset annotation and indexing process.
self._dataset_init_specs[name] = ds.init_spec()
if dataset_loaders is None:
# Attempt to infer an init spec for the dataset before we lose access to
# the original during dataset annotation and indexing.
self._dataset_loaders[name] = (type(ds), ds.init_spec())
      # Annotate the dataset
annotated_ds = self._run_annotators(ds)
# Index the annotated dataset and add it to the app
Expand Down
38 changes: 24 additions & 14 deletions lit_nlp/client/lib/testing_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ export const mockMetadata: LitMetadata = {
'models': {
'sst_0_micro': {
'spec': {
'init': {},
'input': {
'passage': createLitType(TextSegment),
'passage_tokens':
Expand Down Expand Up @@ -90,7 +89,6 @@ export const mockMetadata: LitMetadata = {
},
'sst_1_micro': {
'spec': {
'init': {},
'input': {
'passage': createLitType(TextSegment),
'passage_tokens':
Expand Down Expand Up @@ -125,17 +123,13 @@ export const mockMetadata: LitMetadata = {
},
'datasets': {
'sst_dev': {
'initSpec': {
'split': createLitType(StringLitType),
},
'size': 872,
'spec': {
'passage': createLitType(TextSegment),
'label': createLitType(CategoryLabel, {'vocab': ['0', '1']}),
}
},
'color_test': {
'initSpec': null,
'size': 2,
'spec': {
'testNumFeat0': createLitType(Scalar),
Expand All @@ -145,7 +139,6 @@ export const mockMetadata: LitMetadata = {
}
},
'penguin_dev': {
'initSpec': {},
'size': 10,
'spec': {
'body_mass_g': createLitType(Scalar, {
Expand Down Expand Up @@ -189,6 +182,17 @@ export const mockMetadata: LitMetadata = {
'pca': emptySpec(),
'umap': emptySpec(),
},
'initSpecs': {
'datasets': {
'sst_dev': {'split': createLitType(StringLitType)},
'color_test': null,
'penguin_dev': {}
},
'models': {
'sst_0_micro': {},
'sst_1_micro': {},
},
},
'layouts': {},
'demoMode': false,
'defaultLayout': 'default',
Expand All @@ -204,7 +208,6 @@ export const mockSerializedMetadata: SerializedLitMetadata = {
'models': {
'sst_0_micro': {
'spec': {
'init': {},
'input': {
'passage': {'__name__': 'TextSegment', 'required': true},
'passage_tokens':
Expand Down Expand Up @@ -251,7 +254,6 @@ export const mockSerializedMetadata: SerializedLitMetadata = {
},
'sst_1_micro': {
'spec': {
'init': {},
'input': {
'passage': {'__name__': 'TextSegment', 'required': true},
'passage_tokens':
Expand Down Expand Up @@ -299,9 +301,6 @@ export const mockSerializedMetadata: SerializedLitMetadata = {
},
'datasets': {
'sst_dev': {
'initSpec': {
'split':{'__name__': 'StringLitType', 'required': true}
},
'size': 872,
'spec': {
'passage': {'__name__': 'TextSegment', 'required': true},
Expand All @@ -310,7 +309,6 @@ export const mockSerializedMetadata: SerializedLitMetadata = {
}
},
'color_test': {
'initSpec': null,
'size': 2,
'spec': {
'testNumFeat0': {'__name__': 'Scalar', 'required': true},
Expand All @@ -328,7 +326,6 @@ export const mockSerializedMetadata: SerializedLitMetadata = {
}
},
'penguin_dev': {
'initSpec': {},
'size': 10,
'spec': {
'body_mass_g': {'__name__': 'Scalar', 'step': 1, 'required': true},
Expand Down Expand Up @@ -378,6 +375,19 @@ export const mockSerializedMetadata: SerializedLitMetadata = {
'pca': emptySpec(),
'umap': emptySpec(),
},
'initSpecs': {
'datasets': {
'sst_dev': {
'split': {'__name__': 'StringLitType', 'required': true}
},
'color_test': null,
'penguin_dev': {}
},
'models': {
'sst_0_micro': {},
'sst_1_micro': {},
},
},
'layouts': {},
'demoMode': false,
'defaultLayout': 'default',
Expand Down
10 changes: 8 additions & 2 deletions lit_nlp/client/lib/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ export interface SerializedSpec {
[key: string]: {__name__: string};
}

interface InitSpecMap {
[name: string]: Spec|null; // using null here because None ==> null in Python
}

export interface ComponentInfo {
configSpec: Spec;
metaSpec: Spec;
Expand All @@ -46,7 +50,6 @@ export interface ComponentInfo {

export interface DatasetInfo {
size: number;
initSpec: Spec | null; // using null here because None ==> null in Python
spec: Spec;
description?: string;
}
Expand All @@ -65,7 +68,6 @@ export interface CallConfig {
}

export interface ModelSpec {
init: Spec | null; // using null here because None ==> null in Python
input: Spec;
output: Spec;
}
Expand Down Expand Up @@ -96,6 +98,10 @@ export interface LitMetadata {
onboardStartDoc?: string;
onboardEndDoc?: string;
syncState: boolean;
initSpecs: {
datasets: InitSpecMap;
models: InitSpecMap;
};
}

/**
Expand Down
24 changes: 18 additions & 6 deletions lit_nlp/client/lib/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -173,19 +173,19 @@ export function cloneSpec(spec: Spec): Spec {
/**
* Converts serialized LitTypes within the LitMetadata into LitType instances.
*/
// TODO(b/267200697): Explore optimizing this function using the reviver
// parameter of JSON.parse().
export function deserializeLitTypesInLitMetadata(
metadata: SerializedLitMetadata): LitMetadata {

for (const model of Object.keys(metadata.models)) {
const {spec} = metadata.models[model];
spec.init = spec.init ? deserializeLitTypesInSpec(spec.init) : null;
spec.input = deserializeLitTypesInSpec(spec.input);
spec.output = deserializeLitTypesInSpec(spec.output);
metadata.models[model].spec.input =
deserializeLitTypesInSpec(metadata.models[model].spec.input);
metadata.models[model].spec.output =
deserializeLitTypesInSpec(metadata.models[model].spec.output);
}

for (const dataset of Object.keys(metadata.datasets)) {
metadata.datasets[dataset].initSpec = metadata.datasets[dataset].initSpec ?
deserializeLitTypesInSpec(metadata.datasets[dataset].initSpec) : null;
metadata.datasets[dataset].spec =
deserializeLitTypesInSpec(metadata.datasets[dataset].spec);
}
Expand All @@ -204,6 +204,18 @@ export function deserializeLitTypesInLitMetadata(
deserializeLitTypesInSpec(metadata.interpreters[interpreter].metaSpec);
}

for (const dataset of Object.keys(metadata.initSpecs.datasets)) {
if (metadata.initSpecs.datasets[dataset] == null) continue;
metadata.initSpecs.datasets[dataset] =
deserializeLitTypesInSpec(metadata.initSpecs.datasets[dataset]);
}

for (const model of Object.keys(metadata.initSpecs.models)) {
if (metadata.initSpecs.models[model] == null) continue;
metadata.initSpecs.models[model] =
deserializeLitTypesInSpec(metadata.initSpecs.models[model]);
}

return metadata;
}

Expand Down
5 changes: 0 additions & 5 deletions lit_nlp/client/services/classification_service_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ MULTICLASS_PRED_WITH_THRESHOLD.null_idx = 0;
MULTICLASS_PRED_WITH_THRESHOLD.vocab = ['0', '1'];
MULTICLASS_PRED_WITH_THRESHOLD.threshold = 0.3;
const MULTICLASS_SPEC_WITH_THRESHOLD: ModelSpec = {
init: null,
input: {},
output: {[FIELD_NAME]: MULTICLASS_PRED_WITH_THRESHOLD}
};
Expand All @@ -22,29 +21,25 @@ const MULTICLASS_PRED_WITHOUT_THRESHOLD = new MulticlassPreds();
MULTICLASS_PRED_WITHOUT_THRESHOLD.null_idx = 0;
MULTICLASS_PRED_WITHOUT_THRESHOLD.vocab = ['0', '1'];
const MULTICLASS_SPEC_WITHOUT_THRESHOLD: ModelSpec = {
init: null,
input: {},
output: {[FIELD_NAME]: MULTICLASS_PRED_WITHOUT_THRESHOLD}
};

const MULTICLASS_PRED_NO_VOCAB = new MulticlassPreds();
MULTICLASS_PRED_NO_VOCAB.null_idx = 0;
const INVALID_SPEC_NO_VOCAB: ModelSpec = {
init: null,
input: {},
output: {[FIELD_NAME]: MULTICLASS_PRED_NO_VOCAB}
};

const MULTICLASS_PRED_NO_NULL_IDX = new MulticlassPreds();
MULTICLASS_PRED_NO_NULL_IDX.vocab = ['0', '1'];
const INVALID_SPEC_NO_NULL_IDX: ModelSpec = {
init: null,
input: {},
output: {[FIELD_NAME]: MULTICLASS_PRED_NO_NULL_IDX}
};

const INVALID_SPEC_NO_MULTICLASS_PRED: ModelSpec = {
init: null,
input: {},
output: {}
};
Expand Down

0 comments on commit f3b0d6e

Please sign in to comment.