From 8c6ac1174cd1020c00491736a3d0fa78e05e0eed Mon Sep 17 00:00:00 2001 From: Crystal Qian Date: Wed, 27 Jul 2022 11:28:23 -0700 Subject: [PATCH] Replace hard-coded type approximations with references to the new class-based LitTypes. PiperOrigin-RevId: 463636591 --- lit_nlp/api/types.py | 10 +- lit_nlp/api/types_test.py | 6 -- .../client/elements/interpreter_controls.ts | 69 +++++++------- lit_nlp/client/lib/generated_text_utils.ts | 5 +- .../client/lib/generated_text_utils_test.ts | 3 +- lit_nlp/client/lib/lit_types.ts | 34 +++++-- lit_nlp/client/lib/lit_types_utils.ts | 95 +++++++++++++++++-- lit_nlp/client/lib/lit_types_utils_test.ts | 45 ++++++++- lit_nlp/client/lib/types.ts | 44 +-------- lit_nlp/client/lib/utils.ts | 16 ++-- lit_nlp/client/modules/attention_module.ts | 6 +- .../client/modules/classification_module.ts | 26 +++-- .../counterfactual_explainer_module.ts | 4 +- lit_nlp/client/modules/curves_module.ts | 18 ++-- lit_nlp/client/modules/data_table_module.ts | 5 +- .../client/modules/datapoint_editor_module.ts | 13 +-- .../modules/feature_attribution_module.ts | 22 +++-- .../client/modules/generated_image_module.ts | 16 ++-- .../client/modules/generated_text_module.ts | 5 +- lit_nlp/client/modules/generator_module.ts | 48 +++++----- .../client/modules/lm_prediction_module.ts | 11 ++- lit_nlp/client/modules/multilabel_module.ts | 4 +- lit_nlp/client/modules/pdp_module.ts | 9 +- lit_nlp/client/modules/regression_module.ts | 4 +- .../modules/salience_clustering_module.ts | 15 +-- lit_nlp/client/modules/salience_map_module.ts | 18 ++-- lit_nlp/client/modules/scalar_module.ts | 14 ++- .../modules/sequence_salience_module.ts | 4 +- lit_nlp/client/modules/span_graph_module.ts | 8 +- lit_nlp/client/modules/tcav_module.ts | 5 +- lit_nlp/client/modules/tda_module.ts | 28 +++--- lit_nlp/client/services/api_service.ts | 5 +- .../client/services/classification_service.ts | 7 +- lit_nlp/client/services/data_service.ts | 33 ++++--- lit_nlp/client/services/group_service.ts | 10 +- lit_nlp/client/services/state_service.ts | 18 ++-- 36 files changed, 414 insertions(+), 269 deletions(-) diff --git a/lit_nlp/api/types.py b/lit_nlp/api/types.py index 34951216..f658eb51 100644 --- a/lit_nlp/api/types.py +++ b/lit_nlp/api/types.py @@ -486,26 +486,26 @@ class MultiFieldMatcher(LitType): @attr.s(auto_attribs=True, frozen=True, kw_only=True) -class _Salience(LitType): +class Salience(LitType): """Metadata about a returned salience map.""" autorun: bool = False # If the saliency technique is automatically run. signed: bool # If the returned values are signed. @attr.s(auto_attribs=True, frozen=True, kw_only=True) -class TokenSalience(_Salience): +class TokenSalience(Salience): """Metadata about a returned token salience map.""" default: dtypes.TokenSalience = None @attr.s(auto_attribs=True, frozen=True, kw_only=True) -class FeatureSalience(_Salience): +class FeatureSalience(Salience): """Metadata about a returned feature salience map.""" default: dtypes.FeatureSalience = None @attr.s(auto_attribs=True, frozen=True, kw_only=True) -class ImageSalience(_Salience): +class ImageSalience(Salience): """Metadata about a returned image saliency. The data is returned as an image in the base64 URL encoded format, e.g., @@ -515,7 +515,7 @@ class ImageSalience(_Salience): @attr.s(auto_attribs=True, frozen=True, kw_only=True) -class SequenceSalience(_Salience): +class SequenceSalience(Salience): """Metadata about a returned sequence salience map.""" default: dtypes.SequenceSalienceMap = None diff --git a/lit_nlp/api/types_test.py b/lit_nlp/api/types_test.py index ad7bfb31..61e17c22 100644 --- a/lit_nlp/api/types_test.py +++ b/lit_nlp/api/types_test.py @@ -3,15 +3,9 @@ from lit_nlp.api import types from google3.testing.pybase import googletest -NUM_TYPES = 41 - class TypesTest(googletest.TestCase): - def test_num_littypes(self): - num_types = len(types.all_littypes()) - self.assertEqual(num_types, NUM_TYPES) - def test_inherit_parent_default_type(self): lit_type = types.StringLitType() self.assertIsInstance(lit_type.default, str) diff --git a/lit_nlp/client/elements/interpreter_controls.ts b/lit_nlp/client/elements/interpreter_controls.ts index 2211dfaa..ba1d90b4 100644 --- a/lit_nlp/client/elements/interpreter_controls.ts +++ b/lit_nlp/client/elements/interpreter_controls.ts @@ -19,14 +19,14 @@ import './checkbox'; import '@material/mwc-icon'; -import {property} from 'lit/decorators'; -import {customElement} from 'lit/decorators'; -import { html} from 'lit'; +import {html} from 'lit'; +import {customElement, property} from 'lit/decorators'; import {observable} from 'mobx'; import {ReactiveElement} from '../lib/elements'; +import {CategoryLabel, FieldMatcher, LitType, LitTypeWithVocab, MultiFieldMatcher, Scalar, SparseMultilabel} from '../lib/lit_types'; import {styles as sharedStyles} from '../lib/shared_styles.css'; -import {LitType, Spec} from '../lib/types'; +import {Spec} from '../lib/types'; import {isLitSubtype} from '../lib/utils'; import {styles} from './interpreter_controls.css'; @@ -101,23 +101,25 @@ export class InterpreterControls extends ReactiveElement { renderControls() { const spec = this.spec as Spec; return Object.keys(spec).map(name => { - // Ensure a default value for any of the options provided for setting. if (this.settings[name] == null) { - if (isLitSubtype(spec[name], 'SparseMultilabel')) { - this.settings[name] = spec[name].default as string[]; + if (spec[name] instanceof SparseMultilabel) { + this.settings[name] = (spec[name] as SparseMultilabel).default; } // If select all is True, default value is all of vocab. - if (isLitSubtype(spec[name], 'MultiFieldMatcher')) { - this.settings[name] = spec[name].select_all!? - spec[name].vocab as string[] : - spec[name].default as string[]; + if (spec[name] instanceof MultiFieldMatcher) { + const fieldSpec = spec[name] as MultiFieldMatcher; + this.settings[name] = fieldSpec.select_all ? + fieldSpec.vocab as string[] : + fieldSpec.default; } // FieldMatcher has its vocab set outside of this element. - else if (isLitSubtype(spec[name], ['CategoryLabel', 'FieldMatcher'])) { + else if ( + spec[name] instanceof CategoryLabel || + spec[name] instanceof FieldMatcher) { + const {vocab} = spec[name] as LitTypeWithVocab; this.settings[name] = - spec[name].vocab != null && spec[name].vocab!.length > 0 ? - spec[name].vocab![0] : ''; + vocab != null && vocab.length > 0 ? vocab[0] : ''; } else { this.settings[name] = spec[name].default as string; } @@ -126,7 +128,7 @@ export class InterpreterControls extends ReactiveElement { return html`
- ${(required ? '*':'') + name} + ${(required ? '*' : '') + name}
${this.renderControl(name, spec[name])}
`; @@ -134,21 +136,22 @@ export class InterpreterControls extends ReactiveElement { } renderControl(name: string, controlType: LitType) { - if (isLitSubtype(controlType, ['SparseMultilabel', 'MultiFieldMatcher'])) { + if (controlType instanceof SparseMultilabel || + controlType instanceof MultiFieldMatcher) { + const {vocab} = controlType as LitTypeWithVocab; // Render checkboxes, with the first item selected. - const renderCheckboxes = - () => controlType.vocab!.map(option => { + const renderCheckboxes = () => vocab.map(option => { // tslint:disable-next-line:no-any const change = (e: any) => { if (e.target.checked) { (this.settings[name] as string[]).push(option); } else { - this.settings[name] = (this.settings[name] as string[]).filter( - item => item !== option); + this.settings[name] = (this.settings[name] as string[]) + .filter(item => item !== option); } }; - const isSelected = (this.settings[name] as string[]).indexOf( - option) !== -1; + const isSelected = + (this.settings[name] as string[]).indexOf(option) !== -1; return html` @@ -156,29 +159,31 @@ export class InterpreterControls extends ReactiveElement { `; }); return html`
${renderCheckboxes()}
`; - } else if (isLitSubtype(controlType, ['CategoryLabel', 'FieldMatcher'])) { + } else if ( + controlType instanceof CategoryLabel || + controlType instanceof FieldMatcher) { + const {vocab} = controlType as LitTypeWithVocab; // Render a dropdown, with the first item selected. const updateDropdown = (e: Event) => { const select = (e.target as HTMLSelectElement); - this.settings[name] = controlType.vocab![select?.selectedIndex || 0]; + this.settings[name] = vocab[select?.selectedIndex || 0]; }; - const options = controlType.vocab!.map((option, optionIndex) => { + const options = vocab.map((option, optionIndex) => { return html` `; }); const defaultValue = - controlType.vocab != null && controlType.vocab.length > 0 ? - controlType.vocab[0] : ''; + vocab != null && vocab.length > 0 ? + vocab[0] : + ''; return html``; - } else if (isLitSubtype(controlType, ['Scalar'])) { + } else if (controlType instanceof Scalar) { // Render a slider. - const step = controlType.step!; - const minVal = controlType.min_val!; - const maxVal = controlType.max_val!; + const {step, min_val: minVal, max_val: maxVal} = controlType; const updateSettings = (e: Event) => { const input = (e.target as HTMLInputElement); diff --git a/lit_nlp/client/lib/generated_text_utils.ts b/lit_nlp/client/lib/generated_text_utils.ts index e1bf6b15..01c6db88 100644 --- a/lit_nlp/client/lib/generated_text_utils.ts +++ b/lit_nlp/client/lib/generated_text_utils.ts @@ -20,7 +20,8 @@ */ import difflib from 'difflib'; -import {GeneratedTextCandidate, IndexedInput, Input, LitName, Preds, Spec} from './types'; +import {LitName, LitTypeWithParent} from './lit_types'; +import {GeneratedTextCandidate, IndexedInput, Input, Preds, Spec} from './types'; import {findSpecKeys, isLitSubtype} from './utils'; // tslint:disable-next-line:no-any difflib does not support Closure imports @@ -89,7 +90,7 @@ export function getAllReferenceTexts( // Search input fields: anything referenced in model's output spec const inputReferenceKeys = new Set(); for (const outKey of findSpecKeys(outputSpec, GENERATION_TYPES)) { - const parent = outputSpec[outKey].parent; + const {parent} = outputSpec[outKey] as LitTypeWithParent; if (parent && dataSpec[parent]) { inputReferenceKeys.add(parent); } diff --git a/lit_nlp/client/lib/generated_text_utils_test.ts b/lit_nlp/client/lib/generated_text_utils_test.ts index 49fde372..eaf45d5a 100644 --- a/lit_nlp/client/lib/generated_text_utils_test.ts +++ b/lit_nlp/client/lib/generated_text_utils_test.ts @@ -18,8 +18,9 @@ import 'jasmine'; import {canonicalizeGenerationResults, getAllOutputTexts, getAllReferenceTexts, getFlatTexts, getTextDiff} from './generated_text_utils'; +import {LitType} from './lit_types'; import {createLitType} from './lit_types_utils'; -import {Input, LitType, Preds, Spec} from './types'; +import {Input, Preds, Spec} from './types'; function textSegmentType(): LitType { return createLitType('TextSegment', { diff --git a/lit_nlp/client/lib/lit_types.ts b/lit_nlp/client/lib/lit_types.ts index bb4a4ff1..0dfa9b4d 100644 --- a/lit_nlp/client/lib/lit_types.ts +++ b/lit_nlp/client/lib/lit_types.ts @@ -29,14 +29,15 @@ function registered(target: any) { REGISTRY[target.name] = target; } -const registryKeys = Object.keys(REGISTRY) as ReadonlyArray; +const registryKeys : string[] = Object.keys(REGISTRY); /** * The types of all LitTypes in the registry, e.g. * 'StringLitType' | 'TextSegment' ... */ export type LitName = typeof registryKeys[number]; -type LitClass = 'LitType'; +/** A type alias for the LitType class. */ +export type LitClass = 'LitType'; type ScoredTextCandidates = Array<[text: string, score: number|null]>; /** @@ -56,12 +57,26 @@ export class LitType { // TODO(b/162269499): Replace this with `unknown` after migration. // tslint:disable-next-line:no-any readonly default: any|undefined = null; + // If this type is created from an Annotator. + annotated: boolean = false; // TODO(b/162269499): Update to camel case once we've replaced old LitType. show_in_data_table: boolean = false; // TODO(b/162269499): Add isCompatible functionality. } +/** A type alias for LitType with an align property. */ +export type LitTypeWithAlign = LitType&{align: string}; + +/** A type alias for LitType with a null idx property. */ +export type LitTypeWithNullIdx = LitType&{null_idx: number}; + +/** A type alias for LitType with a parent property. */ +export type LitTypeWithParent = LitType&{parent: string}; + +/** A type alias for LitType with a vocab property. */ +export type LitTypeWithVocab = LitType&{vocab: string[]}; + /** * A string LitType. */ @@ -420,6 +435,10 @@ export class SparseMultilabelPreds extends _StringCandidateList { parent?: string = undefined; } +// TODO(b/162269499): Rename FieldMatcher to SingleFieldMatcher. +/** A type alias for FieldMatcher or MultiFieldMatcher. */ +export type LitTypeOfFieldMatcher = FieldMatcher|MultiFieldMatcher; + /** * For matching spec fields. * @@ -462,7 +481,8 @@ export class MultiFieldMatcher extends LitType { /** * Metadata about a returned salience map. */ -class _Salience extends LitType { +@registered +export class Salience extends LitType { /** If the saliency technique is automatically run. */ autorun: boolean = false; /** If the returned values are signed. */ @@ -473,14 +493,14 @@ class _Salience extends LitType { * Metadata about a returned token salience map. */ @registered -export class TokenSalience extends _Salience { +export class TokenSalience extends Salience { } /** * Metadata about a returned feature salience map. */ @registered -export class FeatureSalience extends _Salience { +export class FeatureSalience extends Salience { // TODO(b/162269499): Add Typescript dtypes so that we can set default types. } @@ -490,14 +510,14 @@ export class FeatureSalience extends _Salience { * data:image/jpg;base64,w4J3k1Bfa... */ @registered -export class ImageSalience extends _Salience { +export class ImageSalience extends Salience { } /** * Metadata about a returned sequence salience map. */ @registered -export class SequenceSalience extends _Salience { +export class SequenceSalience extends Salience { } /** diff --git a/lit_nlp/client/lib/lit_types_utils.ts b/lit_nlp/client/lib/lit_types_utils.ts index 2899054c..ac6b0788 100644 --- a/lit_nlp/client/lib/lit_types_utils.ts +++ b/lit_nlp/client/lib/lit_types_utils.ts @@ -1,6 +1,28 @@ +/** + * @license + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// For consistency with types.ts. +// tslint:disable: enforce-name-casing + import {Spec} from '../lib/types'; import {LitName, LitType, REGISTRY} from './lit_types'; +import {LitMetadata} from './types'; + /** * Creates and returns a new LitType instance. @@ -9,17 +31,20 @@ import {LitName, LitType, REGISTRY} from './lit_types'; * For example, {'show_in_data_table': true}. */ export function createLitType( - typeName: LitName, - constructorParams: {[key: string]: unknown} = {}){ + typeName: LitName, constructorParams: {[key: string]: unknown} = {}) { const litType = REGISTRY[typeName]; - // tslint:disable-next-line:no-any const newType = new (litType as any)(); newType.__name__ = typeName; newType.__mro__ = getMethodResolutionOrder(newType); + // Excluded properties are passed through in the Python serialization + // of LitTypes and can be ignored by the frontend. + const excluded = ['__mro__']; for (const key in constructorParams) { - if (key in newType) { + if (excluded.includes(key)) { + continue; + } else if (key in newType) { newType[key] = constructorParams[key]; } else { throw new Error( @@ -30,6 +55,60 @@ export function createLitType( return newType; } + +interface SerializedSpec { + [key: string]: {__name__: string}; +} + +/** + * Converts serialized LitTypes within a Spec into LitType instances. + */ +export function deserializeLitTypesInSpec(serializedSpec: SerializedSpec): Spec { + const typedSpec: Spec = {}; + for (const key of Object.keys(serializedSpec)) { + typedSpec[key] = + createLitType(serializedSpec[key].__name__, serializedSpec[key] as {}); + } + return typedSpec; +} + + +/** + * Converts serialized LitTypes within the LitMetadata into LitType instances. + */ +export function deserializeLitTypesInLitMetadata(metadata: LitMetadata): + LitMetadata { + for (const model of Object.keys(metadata.models)) { + metadata.models[model].spec.input = + deserializeLitTypesInSpec(metadata.models[model].spec.input); + metadata.models[model].spec.output = + deserializeLitTypesInSpec(metadata.models[model].spec.output); + } + + for (const dataset of Object.keys(metadata.datasets)) { + metadata.datasets[dataset].spec = + deserializeLitTypesInSpec(metadata.datasets[dataset].spec); + } + + for (const generator of Object.keys(metadata.generators)) { + metadata.generators[generator].configSpec = + deserializeLitTypesInSpec(metadata.generators[generator].configSpec); + metadata.generators[generator].metaSpec = + deserializeLitTypesInSpec(metadata.generators[generator].metaSpec); + } + + for (const interpreter of Object.keys(metadata.interpreters)) { + metadata.interpreters[interpreter].configSpec = deserializeLitTypesInSpec( + metadata.interpreters[interpreter].configSpec); + metadata.interpreters[interpreter].metaSpec = + deserializeLitTypesInSpec(metadata.interpreters[interpreter].metaSpec); + } + + metadata.littypes = deserializeLitTypesInSpec(metadata.littypes); + return metadata; +} + + /** * Returns the method resolution order for a given litType. * This is for compatability with references to non-class-based LitTypes, @@ -53,8 +132,7 @@ export function getMethodResolutionOrder(litType: LitType): string[] { * @param litType: The LitType to check. * @param typesToFind: Either a single or list of parent LitType candidates. */ -export function isLitSubtype( - litType: LitType, typesToFind: LitName|LitName[]) { +export function isLitSubtype(litType: LitType, typesToFind: LitName|LitName[]) { if (litType == null) return false; if (typeof typesToFind === 'string') { @@ -63,7 +141,7 @@ export function isLitSubtype( for (const typeName of typesToFind) { // tslint:disable-next-line:no-any - const registryType : any = REGISTRY[typeName]; + const registryType: any = REGISTRY[typeName]; if (litType instanceof registryType) { return true; @@ -80,6 +158,5 @@ export function isLitSubtype( export function findSpecKeys( spec: Spec, typesToFind: LitName|LitName[]): string[] { return Object.keys(spec).filter( - key => isLitSubtype( - spec[key] as LitType, typesToFind)); + key => isLitSubtype(spec[key], typesToFind)); } diff --git a/lit_nlp/client/lib/lit_types_utils_test.ts b/lit_nlp/client/lib/lit_types_utils_test.ts index 800aefc0..d239618d 100644 --- a/lit_nlp/client/lib/lit_types_utils_test.ts +++ b/lit_nlp/client/lib/lit_types_utils_test.ts @@ -16,7 +16,7 @@ describe('createLitType test', () => { const result = litTypesUtils.createLitType('Scalar'); expect(result).toEqual(expected); expect(result instanceof litTypes.Scalar).toEqual(true); - }); + }); it('creates with constructor params', () => { const expected = new litTypes.StringLitType(); @@ -89,8 +89,42 @@ describe('isLitSubtype test', () => { }); +describe('deserializeLitTypesInSpec test', () => { + // TODO(b/162269499): Add test for deserializeLitTypesInLitMetadata. + const testSpec = { + 'probabilities': { + '__class__': 'LitType', + '__name__': 'MulticlassPreds', + '__mro__': ['MulticlassPreds', 'LitType', 'object'], + 'required': true, + 'vocab': ['0', '1'], + 'null_idx': 0, + 'parent': 'label' + }, + 'pooled_embs': { + '__class__': 'LitType', + '__name__': 'Embeddings', + '__mro__': ['Embeddings', 'LitType', 'object'], + 'required': true + } + }; + + it('returns serialized littypes', () => { + expect(testSpec['probabilities'] instanceof litTypes.MulticlassPreds) + .toBe(false); + const result = litTypesUtils.deserializeLitTypesInSpec(testSpec); + expect(result['probabilities']) + .toEqual(litTypesUtils.createLitType( + 'MulticlassPreds', + {'vocab': ['0', '1'], 'null_idx': 0, 'parent': 'label'})); + expect(result['probabilities'] instanceof litTypes.MulticlassPreds) + .toBe(true); + }); +}); + + describe('findSpecKeys test', () => { - // TODO(cjqian): Add original utils_test test after adding more types. + // TODO(cjqian): Add original litTypesUtils.test test after adding more types. const spec: Spec = { 'scalar_foo': new litTypes.Scalar(), 'segment': new litTypes.StringLitType(), @@ -105,9 +139,9 @@ describe('findSpecKeys test', () => { ]); // Keys are in spec. - expect(litTypesUtils.findSpecKeys(spec, ['StringLitType', 'Scalar'])).toEqual([ - 'scalar_foo', 'segment', 'generated_text' - ]); + expect(litTypesUtils.findSpecKeys(spec, [ + 'StringLitType', 'Scalar' + ])).toEqual(['scalar_foo', 'segment', 'generated_text']); }); it('handles empty spec keys', () => { @@ -118,4 +152,5 @@ describe('findSpecKeys test', () => { expect(() => litTypesUtils.findSpecKeys(spec, '')).toThrowError(); expect(() => litTypesUtils.findSpecKeys(spec, 'NotAType')).toThrowError(); }); + }); diff --git a/lit_nlp/client/lib/types.ts b/lit_nlp/client/lib/types.ts index 039e22bd..423b392f 100644 --- a/lit_nlp/client/lib/types.ts +++ b/lit_nlp/client/lib/types.ts @@ -19,56 +19,16 @@ import * as d3 from 'd3'; import {TemplateResult} from 'lit'; +import {LitName, LitType} from './lit_types'; import {chunkWords, isLitSubtype} from './utils'; // tslint:disable-next-line:no-any export type D3Selection = d3.Selection; -export type LitClass = 'LitType'; -export type LitName = 'type'|'LitType'|'StringLitType'|'TextSegment'| - 'GeneratedText'|'GeneratedTextCandidates'|'ReferenceTexts'|'URL'| - 'SearchQuery'|'Tokens'|'TokenTopKPreds'|'Scalar'|'RegressionScore'| - 'CategoryLabel'|'MulticlassPreds'|'SequenceTags'|'SpanLabels'|'EdgeLabels'| - 'MultiSegmentAnnotations'|'Embeddings'|'TokenGradients'|'TokenEmbeddings'| - 'AttentionHeads'|'SparseMultilabel'|'FieldMatcher'|'MultiFieldMatcher'| - 'Gradients'|'Boolean'|'TokenSalience'|'ImageBytes'|'SparseMultilabelPreds'| - 'ImageGradients'|'ImageSalience'|'SequenceSalience'|'ReferenceScores'| - 'FeatureSalience'|'TopTokens'|'CurveDataPoints'|'InfluentialExamples'| - 'GeneratedURL'; - +// TODO(b/162269499): Replace this with class-based lists. export const listFieldTypes: LitName[] = ['Tokens', 'SequenceTags', 'SpanLabels', 'EdgeLabels', 'SparseMultilabel']; -export interface LitType { - __class__: LitClass|'type'; - __name__: LitName; - __mro__: string[]; - parent?: string; - align?: string; - align_in?: string; - align_out?: string; - vocab?: string[]; - null_idx?: number; - required?: boolean; - annotated?: boolean; - default?: string|string[]|number|number[]; - spec?: string; - types?: LitName|LitName[]; - min_val?: number; - max_val?: number; - step?: number; - exclusive?: boolean; - background?: boolean; - separator?: string; - autorun?: boolean; - signed?: boolean; - mask_token?: string; - token_prefix?: string; - select_all?: boolean; - autosort?: boolean; - show_in_data_table?: boolean; -} - export interface Spec { [key: string]: LitType; } diff --git a/lit_nlp/client/lib/utils.ts b/lit_nlp/client/lib/utils.ts index 06eb2910..c53cabbd 100644 --- a/lit_nlp/client/lib/utils.ts +++ b/lit_nlp/client/lib/utils.ts @@ -25,7 +25,8 @@ import {html, TemplateResult} from 'lit'; import {unsafeHTML} from 'lit/directives/unsafe-html.js'; import {marked} from 'marked'; -import {FacetMap, LitName, LitType, ModelInfoMap, Spec} from './types'; +import {LitName, LitType, LitTypeWithParent, MulticlassPreds} from './lit_types'; +import {FacetMap, ModelInfoMap, Spec} from './types'; /** Calculates the mean for a list of numbers */ export function mean(values: number[]): number { @@ -305,14 +306,17 @@ export function doesInputSpecContain( /** Returns if a LitType specifies binary classification. */ export function isBinaryClassification(litType: LitType) { - const predictionLabels = litType.vocab!; - const nullIdx = litType.null_idx; - return predictionLabels.length === 2 && nullIdx != null; + if (litType instanceof MulticlassPreds) { + const {vocab, null_idx: nullIdx} = litType; + return vocab.length === 2 && nullIdx != null; + } + + return false; } /** Returns if a LitType has a parent field. */ export function hasParent(litType: LitType) { - return litType.parent != null; + return (litType as LitTypeWithParent).parent != null; } /** @@ -518,4 +522,4 @@ export function linearSpace( values.push(minValue + i * step); } return values; -} \ No newline at end of file +} diff --git a/lit_nlp/client/modules/attention_module.ts b/lit_nlp/client/modules/attention_module.ts index de517036..6dda327d 100644 --- a/lit_nlp/client/modules/attention_module.ts +++ b/lit_nlp/client/modules/attention_module.ts @@ -27,6 +27,7 @@ import {observable} from 'mobx'; import {app} from '../core/app'; import {LitModule} from '../core/lit_module'; +import {AttentionHeads as AttentionHeadsLitType} from '../lib/lit_types'; import {IndexedInput, ModelInfoMap, SCROLL_SYNC_CSS_CLASS, Spec} from '../lib/types'; import {doesOutputSpecContain, findSpecKeys, getTextWidth, getTokOffsets, sumArray} from '../lib/utils'; import {FocusService} from '../services/services'; @@ -135,13 +136,14 @@ export class AttentionModule extends LitModule { private renderAttnHead() { const outputSpec = this.appState.currentModelSpecs[this.model].spec.output; - const fieldSpec = outputSpec[this.selectedLayer!]; + const fieldSpec = + outputSpec[this.selectedLayer!] as AttentionHeadsLitType; // Tokens involved in the attention. const inToks = (this.preds!)[fieldSpec.align_in!] as Tokens; const outToks = (this.preds!)[fieldSpec.align_out!] as Tokens; - const fontFamily = "'Share Tech Mono', monospace"; + const fontFamily = '\'Share Tech Mono\', monospace'; const fontSize = 12; const defaultCharWidth = 6.5; const font = `${fontSize}px ${fontFamily}`; diff --git a/lit_nlp/client/modules/classification_module.ts b/lit_nlp/client/modules/classification_module.ts index b4473d58..df075417 100644 --- a/lit_nlp/client/modules/classification_module.ts +++ b/lit_nlp/client/modules/classification_module.ts @@ -25,6 +25,7 @@ import {observable} from 'mobx'; import {app} from '../core/app'; import {LitModule} from '../core/lit_module'; import {ColumnHeader, SortableTemplateResult, TableData} from '../elements/table'; +import {MulticlassPreds} from '../lib/lit_types'; import {styles as sharedStyles} from '../lib/shared_styles.css'; import {IndexedInput, ModelInfoMap, Spec} from '../lib/types'; import {doesOutputSpecContain, findSpecKeys} from '../lib/utils'; @@ -130,7 +131,7 @@ export class ClassificationModule extends LitModule { const predClassKey = this.dataService.getColumnName( model, predKey, CalculatedColumnType.PREDICTED_CLASS); labeledPredictions[topLevelKey] = {}; - const {parent, vocab} = output[predKey]; + const {parent, vocab} = output[predKey] as MulticlassPreds; const scores = inputs.map(input => this.dataService.getVal(input.id, topLevelKey)); const predictedClasses = @@ -163,20 +164,25 @@ export class ClassificationModule extends LitModule { } override render() { - const hasGroundTruth = this.appState.currentModels.some(model => - Object.values(this.appState.currentModelSpecs[model].spec.output) - .some(feature => feature.parent != null)); + // TODO(b/162269499): Check that feature.parent is within the spec. + const hasGroundTruth = this.appState.currentModels.some( + model => + Object.values(this.appState.currentModelSpecs[model].spec.output) + .some( + feature => feature instanceof MulticlassPreds && + feature.parent != null)); return html`
- ${Object.entries(this.labeledPredictions) - .map(([fieldName, labelRow], i, arr) => { - const featureTable = - this.renderFeatureTable(labelRow, hasGroundTruth); - return arr.length === 1 ? featureTable: html` + ${ + Object.entries(this.labeledPredictions) + .map(([fieldName, labelRow], i, arr) => { + const featureTable = + this.renderFeatureTable(labelRow, hasGroundTruth); + return arr.length === 1 ? featureTable : html` ${featureTable} `; - })} + })}