diff --git a/lit_nlp/client/default/layout.ts b/lit_nlp/client/default/layout.ts index 25e1cae1..2fb823f9 100644 --- a/lit_nlp/client/default/layout.ts +++ b/lit_nlp/client/default/layout.ts @@ -29,6 +29,7 @@ import {CounterfactualExplainerModule} from '../modules/counterfactual_explainer import {DataTableModule, SimpleDataTableModule} from '../modules/data_table_module'; import {DatapointEditorModule, SimpleDatapointEditorModule} from '../modules/datapoint_editor_module'; import {EmbeddingsModule} from '../modules/embeddings_module'; +import {FeatureAttributionModule} from '../modules/feature_attribution_module'; import {GeneratedImageModule} from '../modules/generated_image_module'; import {GeneratedTextModule} from '../modules/generated_text_module'; import {GeneratorModule} from '../modules/generator_module'; @@ -47,7 +48,7 @@ import {TCAVModule} from '../modules/tcav_module'; import {ThresholderModule} from '../modules/thresholder_module'; // clang-format off -const MODEL_PREDS_MODULES: LitModuleType[] = [ +const MODEL_PREDS_MODULES: readonly LitModuleType[] = [ SpanGraphGoldModuleVertical, SpanGraphModuleVertical, ClassificationModule, @@ -60,7 +61,7 @@ const MODEL_PREDS_MODULES: LitModuleType[] = [ GeneratedImageModule, ]; -const DEFAULT_MAIN_GROUP: LitModuleType[] = [ +const DEFAULT_MAIN_GROUP: readonly LitModuleType[] = [ DataTableModule, DatapointEditorModule, SliceModule, @@ -108,6 +109,7 @@ export const LAYOUTS: LitComponentLayouts = { SalienceMapModule, SequenceSalienceModule, AttentionModule, + FeatureAttributionModule, ], 'Clustering': [SalienceClusteringModule], 'Metrics': [ diff --git a/lit_nlp/client/lib/utils.ts b/lit_nlp/client/lib/utils.ts index d61f7662..d667ef8f 100644 --- a/lit_nlp/client/lib/utils.ts +++ b/lit_nlp/client/lib/utils.ts @@ -24,6 +24,28 @@ import * as d3 from 'd3'; // Used for array helpers. import {html, TemplateResult} from 'lit'; import {FacetMap, LitName, LitType, ModelInfoMap, Spec} from './types'; +/** Calculates the mean for a list of numbers */ +export function mean(values: number[]): number { + return values.reduce((a, b) => a + b) / values.length; +} + +/** Calculates the median for a list of numbers. */ +export function median(values: number[]): number { + const sorted = [...values].sort(); + const medIdx = Math.floor(sorted.length / 2); + let median: number; + + if (sorted.length % 2 === 0) { + const upper = sorted[medIdx]; + const lower = sorted[medIdx - 1]; + median = (upper + lower) / 2; + } else { + median = sorted[medIdx]; + } + + return median; +} + /** * Random integer in range [min, max), where min and max are integers * (behavior on floats is undefined). diff --git a/lit_nlp/client/lib/utils_test.ts b/lit_nlp/client/lib/utils_test.ts index d59e7c20..b364ee29 100644 --- a/lit_nlp/client/lib/utils_test.ts +++ b/lit_nlp/client/lib/utils_test.ts @@ -23,6 +23,25 @@ import 'jasmine'; import {Spec} from '../lib/types'; import * as utils from './utils'; +describe('mean test', () => { + it('computes a mean', () => { + const values = [1,3,2,5,4]; + expect(utils.mean(values)).toEqual(3); + }); +}); + +describe('median test', () => { + it('computes a median for a list of integers with odd length', () => { + const values = [1,3,2,5,4]; + expect(utils.median(values)).toEqual(3); + }); + + it('computes a median for a list of integers with even length', () => { + const values = [1,3,2,5,4,6]; + expect(utils.median(values)).toEqual(3.5); + }); +}); + describe('randInt test', () => { it('generates random integers in a given range', async () => { let start = 1; diff --git a/lit_nlp/client/modules/feature_attribution_module.css b/lit_nlp/client/modules/feature_attribution_module.css new file mode 100644 index 00000000..8c08a7f6 --- /dev/null +++ b/lit_nlp/client/modules/feature_attribution_module.css @@ -0,0 +1,25 @@ +.module-toolbar { + border-bottom: 1px solid var(--lit-neutral-300); +} + +.attribution-container { + border-bottom: 1px solid var(--lit-neutral-300); +} + +.expansion-header { + width: calc(100% - 16px); + height: 30px; + padding: 2px 8px; + display: flex; + flex-direction: row; + cursor: pointer; + align-items: center; + justify-content: space-between; + background-color: var(--lit-neutral-100); + color: var(--lit-majtonal-nv-800); + font-weight: bold; +} + +.expansion-header:not(:only-child) { + border-bottom: 1px solid var(--lit-neutral-200); +} diff --git a/lit_nlp/client/modules/feature_attribution_module.ts b/lit_nlp/client/modules/feature_attribution_module.ts new file mode 100644 index 00000000..0070fdbc --- /dev/null +++ b/lit_nlp/client/modules/feature_attribution_module.ts @@ -0,0 +1,351 @@ +/** + * @fileoverview Feature attribution info for tabular ML models. + * + * @license + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// tslint:disable:no-new-decorators +import {html} from 'lit'; +import {customElement} from 'lit/decorators'; +import {computed, observable} from 'mobx'; + +import {app} from '../core/app'; +import {FacetsChange} from '../core/faceting_control'; +import {LitModule} from '../core/lit_module'; +import {TableData} from '../elements/table'; +import {IndexedInput, ModelInfoMap} from '../lib/types'; +import * as utils from '../lib/utils'; +import {findSpecKeys} from '../lib/utils'; +import {AppState, GroupService} from '../services/services'; +import {NumericFeatureBins} from '../services/group_service'; + +import {styles as sharedStyles} from '../lib/shared_styles.css'; +import {styles} from './feature_attribution_module.css'; + +const ALL_DATA = 'Entire Dataset'; + +interface AttributionStats { + min: number; + median: number; + max: number; + mean: number; +} + +interface AttributionStatsMap { + [feature: string]: AttributionStats; +} + +interface FeatureSalience { + salience: {[feature: string]: number}; +} + +interface FeatureSalienceResult { + [key: string]: FeatureSalience; +} + +interface SummariesMap { + [facet: string]: AttributionStatsMap; +} + +interface VisToggles { + [name: string]: boolean; +} + +/** Aggregate feature attribution for tabular ML models. */ +@customElement('feature-attribution-module') +export class FeatureAttributionModule extends LitModule { + // ---- Static Properties and API ---- + static override title = 'Tabular Feature Attribution'; + static override duplicateForExampleComparison = false; + static override duplicateAsRow = true; + + static override get styles() { + return [sharedStyles, styles]; + } + + static override shouldDisplayModule(modelSpecs: ModelInfoMap) { + const appState = app.getService(AppState); + if (appState.metadata == null) return false; + + return Object.values(modelSpecs).some(modelInfo => { + // The model directly outputs FeatureSalience + const hasIntrinsicSalience = + findSpecKeys(modelInfo.spec.output, 'FeatureSalience').length > 0; + + // At least one compatible interpreter outputs FeatureSalience + const canDeriveSalience = modelInfo.interpreters.some(name => { + const {metaSpec} = appState.metadata.interpreters[name]; + return findSpecKeys(metaSpec, 'FeatureSalience').length > 0; + }); + + return hasIntrinsicSalience || canDeriveSalience; + }); + } + + static override template(model = '') { + // clang-format off + return html` + `; + // clang format on + } + + // ---- Instance Properties ---- + + private readonly groupService = app.getService(GroupService); + + @observable private readonly expanded: VisToggles = {[ALL_DATA]: true}; + @observable private features: string[] = []; + @observable private bins: NumericFeatureBins = {}; + @observable private summaries: SummariesMap = {}; + @observable private readonly enabled: VisToggles = { + 'model': this.hasIntrinsicSalience + }; + + @computed + private get facets() { + return this.groupService.groupExamplesByFeatures( + this.bins, this.appState.currentInputData, this.features); + } + + @computed + private get hasIntrinsicSalience() { + if (this.appState.metadata.models[this.model]?.spec?.output) { + return findSpecKeys(this.appState.metadata.models[this.model].spec.output, + 'FeatureSalience').length > 0; + } + return false; + } + + @computed + private get salienceInterpreters() { + const {interpreters} = this.appState.metadata.models[this.model]; + return Object.entries(app.getService(AppState).metadata.interpreters) + .filter(([name, i]) => + interpreters.includes(name) && + findSpecKeys(i.metaSpec,'FeatureSalience')) + .map(([name]) => name); + } + + // ---- Private API ---- + + /** + * Retrieves and summarizes `FeatureSalience` info from the model predictions + * for the named facet (i.e., subset of data), and adds the summaries to the + * module's state. + * + * Models may provide `FeatureSalience` data in multiple output features. + * Summmaries are created and stored on state for each feature-facet pair. + */ + private async predict(facet: string, data: IndexedInput[]) { + const promise = this.apiService.getPreds( + data, this.model, this.appState.currentDataset, ['FeatureSalience']); + const results = await this.loadLatest('predictionScores', promise); + + if (results == null) return; + + const outputSpec = this.appState.metadata.models[this.model].spec.output; + const salienceKeys = findSpecKeys(outputSpec, 'FeatureSalience'); + + for (const key of salienceKeys) { + const summaryName = `Feature: ${key} | Facet: ${facet}`; + const values = results.map(res => res[key]) as FeatureSalience[]; + const summary = this.summarize(values); + if (summary) this.summaries[summaryName] = summary; + this.expanded[summaryName] = Object.keys(this.summaries).length === 1; + } + } + + /** Updates salience summaries provided by the model. */ + private async updateModelAttributions() { + if (!this.enabled['model']) return; + + await this.predict(ALL_DATA, this.appState.currentInputData); + + if (this.features.length) { + for (const [facet, group] of Object.entries(this.facets)) { + await this.predict(facet, group.data); + } + } + } + + /** + * Retrieves and summarizes `FeatureSalience` info from the given interpreter + * for the named facet (i.e., subset of data), and adds the summaries to the + * module's state. + * + * Interpreters may provide `FeatureSalience` data in multiple output fields. + * Summmaries are created and stored on state for each field-facet pair. + */ + private async interpret(name: string, facet: string, data: IndexedInput[]) { + const runKey = `interpretations-${name}`; + const promise = this.apiService.getInterpretations( + data, this.model, this.appState.currentDataset, name, {}, + `Running ${name}`); + const results = + (await this.loadLatest(runKey, promise)) as FeatureSalienceResult[]; + + if (results == null || results.length === 0) return; + + // TODO(b/217724273): figure out a more elegant way to handle + // variable-named output fields with metaSpec. + const {metaSpec} = this.appState.metadata.interpreters[name]; + const {output} = this.appState.getModelSpec(this.model); + const spec = {...metaSpec, ...output}; + const salienceKeys = findSpecKeys(spec, 'FeatureSalience'); + + for (const key of salienceKeys) { + if (results[0][key] != null) { + const salience = results.map((a: FeatureSalienceResult) => a[key]); + const summaryName = + `Interpreter: ${name} | Key: ${key} | Facet: ${facet}`; + this.summaries[summaryName] = this.summarize(salience); + this.expanded[summaryName] = Object.keys(this.summaries).length === 1; + } + } + } + + /** Updates salience summaries for all enabled interpreters. */ + private async updateInterpreterAttributions() { + const interpreters = this.salienceInterpreters.filter(i => this.enabled[i]); + for (const interpreter of interpreters) { + await this.interpret(interpreter, ALL_DATA, + this.appState.currentInputData); + + if (this.features.length) { + for (const [facet, group] of Object.entries(this.facets)) { + await this.interpret(interpreter, facet, group.data); + } + } + } + } + + private updateSummaries() { + this.summaries = {}; + this.updateModelAttributions(); + this.updateInterpreterAttributions(); + } + + /** + * Summarizes the distribution of feature attribution values of a dataset + */ + private summarize(data: FeatureSalience[]) { + const statsMap: AttributionStatsMap = {}; + const fields = Object.keys(data[0].salience); + + for (const field of fields) { + const values = data.map(d => d.salience[field]); + const min = Math.min(...values); + const max = Math.max(...values); + const mean = utils.mean(values); + const median = utils.median(values); + statsMap[field] = {min, median, max, mean}; + } + + return statsMap; + } + + private renderFacetControls() { + const updateFacets = (event: CustomEvent) => { + this.features = event.detail.features; + this.bins = event.detail.bins; + }; + + // clang-format off + return html` + `; + // clang-format on + } + + private renderSalienceControls() { + const change = (name: string) => { + this.enabled[name] = !this.enabled[name]; + }; + // clang-format off + return html` + Show attributions from: + ${this.hasIntrinsicSalience ? + html` {change('model');}}> + ` : null} + ${this.salienceInterpreters.map(interp => + html` {change(interp);}}> + `)}`; + // clang-format on + } + + private renderTable(summary: AttributionStatsMap) { + const columnNames = ['field', 'min', 'median', 'max', 'mean']; + const tableData: TableData[] = + Object.entries(summary).map(([feature, stats]) => { + const {min, median, max, mean} = stats; + return [feature, min, median, max, mean]; + }); + + // clang-format off + return html``; + // clang-format on + } + + // ---- Public API ---- + + override firstUpdated() { + const dataChange = () => [this.appState.currentInputData, this.features, + this.model, Object.values(this.enabled)]; + this.react(dataChange, () => {this.updateSummaries();}); + + this.enabled['model'] = this.hasIntrinsicSalience; + this.updateSummaries(); + } + + override render() { + // clang-format off + return html` +
+
${this.renderSalienceControls()}
+
${this.renderFacetControls()}
+
+ ${Object.entries(this.summaries) + .sort() + .map(([facet, summary]) => { + const toggle = () => { + this.expanded[facet] = !this.expanded[facet]; + }; + return html` +
+
+
${facet}
+ + ${this.expanded[facet] ? 'expand_less' : 'expand_more'} + +
+ ${this.expanded[facet] ? this.renderTable(summary) : ''} +
`; + })} +
+
`; + // clang-format on + } +} + +declare global { + interface HTMLElementTagNameMap { + 'feature-attribution-module': FeatureAttributionModule; + } +} diff --git a/lit_nlp/client/modules/salience_map_module.ts b/lit_nlp/client/modules/salience_map_module.ts index 86996753..8f0d2ddb 100644 --- a/lit_nlp/client/modules/salience_map_module.ts +++ b/lit_nlp/client/modules/salience_map_module.ts @@ -59,11 +59,14 @@ interface FeatureSalienceResult { [key: string]: {salience: FeatureSalienceMap}; } +type SalienceResult = TokenSalienceResult | ImageSalienceResult | + FeatureSalienceResult; + /** * UI status for each interpreter. */ interface InterpreterState { - salience: TokenSalienceResult|ImageSalienceResult|FeatureSalienceResult; + salience: SalienceResult; autorun: boolean; isLoading: boolean; cmap: SalienceCmap; @@ -280,8 +283,7 @@ export class SalienceMapModule extends LitModule {
- - `; + `; // clang-format on } @@ -309,19 +311,16 @@ export class SalienceMapModule extends LitModule {
- - `; + `; // clang-format on } - renderGroup( - salience: TokenSalienceResult|ImageSalienceResult|FeatureSalienceResult, - gradKey: string, cmap: SalienceCmap) { - const spec = this.appState.getModelSpec(this.model); - if (isLitSubtype(spec.output[gradKey], 'ImageGradients')) { + renderGroup(salience: SalienceResult, spec: Spec, gradKey: string, + cmap: SalienceCmap) { + if (isLitSubtype(spec[gradKey], 'ImageGradients')) { salience = salience as ImageSalienceResult; return this.renderImage(salience, gradKey); - } else if (isLitSubtype(spec.output[gradKey], 'FeatureSalience')) { + } else if (isLitSubtype(spec[gradKey], 'FeatureSalience')) { salience = salience as FeatureSalienceResult; return this.renderFeatureSalience(salience, gradKey, cmap); } else { @@ -345,12 +344,9 @@ export class SalienceMapModule extends LitModule { this.state[name].autorun = !this.state[name].autorun; }; // clang-format off - return html` - - - `; + return html` + `; // clang-format on }); } @@ -379,21 +375,25 @@ export class SalienceMapModule extends LitModule { } } return html` - + `; }; + return html` ${Object.keys(this.state).map(name => { if (!this.state[name].autorun) { return null; } + // TODO(b/217724273): figure out a more elegant way to handle + // variable-named output fields with metaSpec. + const {output} = this.appState.getModelSpec(this.model); + const {metaSpec} = this.appState.metadata.interpreters[name]; + const spec = {...metaSpec, ...output}; const salience = this.state[name].salience; const description = - this.appState.metadata.interpreters[name].description || name; + this.appState.metadata.interpreters[name].description ?? name; return html` @@ -433,21 +434,15 @@ export class SalienceMapModule extends LitModule { // Ensure there are salience interpreters for loaded models. const appState = app.getService(AppState); - for (const modelInfo of Object.values(modelSpecs)) { - for (let i = 0; i < modelInfo.interpreters.length; i++) { - const interpreterName = modelInfo.interpreters[i]; - if (appState.metadata == null) { - return false; - } - const interpreter = appState.metadata.interpreters[interpreterName]; - const salienceKeys = findSpecKeys( - interpreter.metaSpec, SalienceMapModule.salienceTypes); - if (salienceKeys.length !== 0) { - return true; - } - } - } - return false; + if (appState.metadata == null) return false; + + return Object.values(modelSpecs).some(modelInfo => + modelInfo.interpreters.some(name => { + const interpreter = appState.metadata.interpreters[name]; + const salienceKeys = findSpecKeys(interpreter.metaSpec, + SalienceMapModule.salienceTypes); + return salienceKeys.length > 0; + })); } }
@@ -402,7 +402,8 @@ export class SalienceMapModule extends LitModule { ${Object.keys(salience).map(gradKey => - this.renderGroup(salience, gradKey, this.state[name].cmap))} + this.renderGroup(salience, spec, gradKey, + this.state[name].cmap))} ${this.state[name].isLoading ? this.renderSpinner() : null}