From 2efe62b77dbfddf698b0a79408c0227bd21dc959 Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Mon, 1 Apr 2024 10:12:01 -0700 Subject: [PATCH] Plot PR/ROC curves for slices. PiperOrigin-RevId: 620880580 --- lit_nlp/client/modules/curves_module.ts | 108 +++++++++++++++++++----- 1 file changed, 85 insertions(+), 23 deletions(-) diff --git a/lit_nlp/client/modules/curves_module.ts b/lit_nlp/client/modules/curves_module.ts index 27650837..93246468 100644 --- a/lit_nlp/client/modules/curves_module.ts +++ b/lit_nlp/client/modules/curves_module.ts @@ -20,7 +20,7 @@ import '../elements/expansion_panel'; import '../elements/line_chart'; import {html, TemplateResult} from 'lit'; -import {customElement} from 'lit/decorators.js'; +import {customElement, state} from 'lit/decorators.js'; import {action, computed, observable} from 'mobx'; import {app} from '../core/app'; @@ -31,7 +31,7 @@ import {styles as sharedStyles} from '../lib/shared_styles.css'; import {type GroupedExamples, IndexedInput, ModelInfoMap, SCROLL_SYNC_CSS_CLASS, Spec} from '../lib/types'; import {findSpecKeys, hasValidParent} from '../lib/utils'; import {NumericFeatureBins} from '../services/group_service'; -import {GroupService} from '../services/services'; +import {GroupService, SliceService} from '../services/services'; import {styles} from './curves_module.css'; @@ -50,6 +50,10 @@ interface CurvesData { rocCurve: Map; } +interface CurvesDataMap { + [name: string]: CurvesData[]; +} + /** * A LIT module that renders PR/ROC curves. */ @@ -72,9 +76,12 @@ export class CurvesModule extends LitModule { ]; } + @state() private showSlices = false; + @observable private readonly isPanelCollapsed = new Map(); @observable private datasetCurves?: CurvesData[]; - @observable private groupedCurves: {[group: string]: CurvesData[]} = {}; + @observable private sliceCurves: CurvesDataMap = {}; + @observable private groupedCurves: CurvesDataMap = {}; @observable private selectedPredKeyIndex = 0; @observable private selectedPositiveLabelIndex = 0; @observable private positiveLabelOptions: string[] = []; @@ -85,6 +92,7 @@ export class CurvesModule extends LitModule { @observable private readonly selectedFacets: string[] = []; private readonly groupService = app.getService(GroupService); + private readonly sliceService = app.getService(SliceService); // TODO(b/204677206): Using document.createElement() here may be inducing this // module to schedule an update while another update is already in progress. // Note that this was introduced in cl/463915592 in order to preserve the @@ -105,24 +113,49 @@ export class CurvesModule extends LitModule { override connectedCallback() { super.connectedCallback(); + this.reactImmediately( - () => [this.appState.currentInputData, this.predKey, - this.positiveLabel], async () => { - this.datasetCurves = await this.getCurveData( - this.appState.currentInputData, this.predKey, - this.positiveLabel); - }); + () => [ + this.appState.currentInputData, + this.predKey, + this.positiveLabel + ] as const, + async () => { + this.datasetCurves = await this.getCurveData( + this.appState.currentInputData, + this.predKey, + this.positiveLabel); + }); + + this.reactImmediately( + () => [this.groupedExamples, this.predKey, this.positiveLabel] as const, + async ([groupedExamples, predKey, positiveLabel]) => { + await this.getGroupedCurveData( + groupedExamples, predKey, positiveLabel + ); + }); + this.reactImmediately( - () => [this.groupedExamples, this.predKey, this.positiveLabel], + () => [this.appState.currentModels] as const, async () => { - await this.getGroupedCurveData(this.groupedExamples, this.predKey, - this.positiveLabel); + this.updateLabels(); + this.datasetCurves = await this.getCurveData( + this.appState.currentInputData, + this.predKey, + this.positiveLabel + ); + }); + + this.reactImmediately( + () => [ + this.sliceService.sliceNames, + this.sliceMembers, + this.predKey, + this.positiveLabel + ] as const, + async ([namedNames, sliceMembers, predKey, positiveLabel]) => { + await this.getSliceCurveData(namedNames, predKey, positiveLabel); }); - this.reactImmediately(() => [this.appState.currentModels], async () => { - this.updateLabels(); - this.datasetCurves = await this.getCurveData( - this.appState.currentInputData, this.predKey, this.positiveLabel); - }); } @action @@ -139,6 +172,22 @@ export class CurvesModule extends LitModule { } } + @action + private async getSliceCurveData( + slices: Iterable, + predKey: string, + positiveLabel: string + ) { + this.sliceCurves = {}; + for (const name of slices) { + this.sliceCurves[name] = await this.getCurveData( + this.sliceService.getSliceDataByName(name), + predKey, + positiveLabel + ); + } + } + private updateLabels() { let positiveLabelOptions: string[] = []; for (const modelName of this.appState.currentModels) { @@ -209,6 +258,12 @@ export class CurvesModule extends LitModule { return data; } + @computed private get sliceMembers(): string[] { + return [ + ...this.sliceService.namedSlices.values() + ].flatMap((sliceData) => [...sliceData.values()]); + } + @computed get predKeyOptions() { return this.appState.currentModels.flatMap((modelName: string) => { @@ -354,14 +409,18 @@ export class CurvesModule extends LitModule { return html` ${this.renderPredKeySelect()} ${this.renderPositiveLabelSelect()} + {this.showSlices = !this.showSlices;}} + ?disabled=${this.sliceService.sliceNames.length === 0}> + ${this.facetingControl} `; // clang-format on } override renderImpl() { - const groups = Object.keys(this.groupedCurves); - // clang-format off return html`
@@ -371,12 +430,15 @@ export class CurvesModule extends LitModule {
${this.datasetCurves ? this.renderCharts(this.datasetCurves, 'Dataset', true) : null} - ${groups.map(group => - this.renderCharts(this.groupedCurves[group], group, false) + ${this.showSlices ? + Object.entries(this.sliceCurves).map( + ([name, data]) => this.renderCharts(data, name, false)) : + null} + ${Object.entries(this.groupedCurves).map( + ([name, data]) => this.renderCharts(data, name, false) )}
-
- `; + `; // clang-format on }