Skip to content

Commit

Permalink
Plot PR/ROC curves for slices.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 620880580
  • Loading branch information
RyanMullins authored and LIT team committed Apr 1, 2024
1 parent 3d61a09 commit 2efe62b
Showing 1 changed file with 85 additions and 23 deletions.
108 changes: 85 additions & 23 deletions lit_nlp/client/modules/curves_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import '../elements/expansion_panel';
import '../elements/line_chart';

import {html, TemplateResult} from 'lit';
import {customElement} from 'lit/decorators.js';
import {customElement, state} from 'lit/decorators.js';
import {action, computed, observable} from 'mobx';

import {app} from '../core/app';
Expand All @@ -31,7 +31,7 @@ import {styles as sharedStyles} from '../lib/shared_styles.css';
import {type GroupedExamples, IndexedInput, ModelInfoMap, SCROLL_SYNC_CSS_CLASS, Spec} from '../lib/types';
import {findSpecKeys, hasValidParent} from '../lib/utils';
import {NumericFeatureBins} from '../services/group_service';
import {GroupService} from '../services/services';
import {GroupService, SliceService} from '../services/services';

import {styles} from './curves_module.css';

Expand All @@ -50,6 +50,10 @@ interface CurvesData {
rocCurve: Map<number, number>;
}

interface CurvesDataMap {
[name: string]: CurvesData[];
}

/**
* A LIT module that renders PR/ROC curves.
*/
Expand All @@ -72,9 +76,12 @@ export class CurvesModule extends LitModule {
];
}

@state() private showSlices = false;

@observable private readonly isPanelCollapsed = new Map();
@observable private datasetCurves?: CurvesData[];
@observable private groupedCurves: {[group: string]: CurvesData[]} = {};
@observable private sliceCurves: CurvesDataMap = {};
@observable private groupedCurves: CurvesDataMap = {};
@observable private selectedPredKeyIndex = 0;
@observable private selectedPositiveLabelIndex = 0;
@observable private positiveLabelOptions: string[] = [];
Expand All @@ -85,6 +92,7 @@ export class CurvesModule extends LitModule {
@observable private readonly selectedFacets: string[] = [];

private readonly groupService = app.getService(GroupService);
private readonly sliceService = app.getService(SliceService);
// TODO(b/204677206): Using document.createElement() here may be inducing this
// module to schedule an update while another update is already in progress.
// Note that this was introduced in cl/463915592 in order to preserve the
Expand All @@ -105,24 +113,49 @@ export class CurvesModule extends LitModule {

override connectedCallback() {
super.connectedCallback();

this.reactImmediately(
() => [this.appState.currentInputData, this.predKey,
this.positiveLabel], async () => {
this.datasetCurves = await this.getCurveData(
this.appState.currentInputData, this.predKey,
this.positiveLabel);
});
() => [
this.appState.currentInputData,
this.predKey,
this.positiveLabel
] as const,
async () => {
this.datasetCurves = await this.getCurveData(
this.appState.currentInputData,
this.predKey,
this.positiveLabel);
});

this.reactImmediately(
() => [this.groupedExamples, this.predKey, this.positiveLabel] as const,
async ([groupedExamples, predKey, positiveLabel]) => {
await this.getGroupedCurveData(
groupedExamples, predKey, positiveLabel
);
});

this.reactImmediately(
() => [this.groupedExamples, this.predKey, this.positiveLabel],
() => [this.appState.currentModels] as const,
async () => {
await this.getGroupedCurveData(this.groupedExamples, this.predKey,
this.positiveLabel);
this.updateLabels();
this.datasetCurves = await this.getCurveData(
this.appState.currentInputData,
this.predKey,
this.positiveLabel
);
});

this.reactImmediately(
() => [
this.sliceService.sliceNames,
this.sliceMembers,
this.predKey,
this.positiveLabel
] as const,
async ([namedNames, sliceMembers, predKey, positiveLabel]) => {
await this.getSliceCurveData(namedNames, predKey, positiveLabel);
});
this.reactImmediately(() => [this.appState.currentModels], async () => {
this.updateLabels();
this.datasetCurves = await this.getCurveData(
this.appState.currentInputData, this.predKey, this.positiveLabel);
});
}

@action
Expand All @@ -139,6 +172,22 @@ export class CurvesModule extends LitModule {
}
}

@action
private async getSliceCurveData(
slices: Iterable<string>,
predKey: string,
positiveLabel: string
) {
this.sliceCurves = {};
for (const name of slices) {
this.sliceCurves[name] = await this.getCurveData(
this.sliceService.getSliceDataByName(name),
predKey,
positiveLabel
);
}
}

private updateLabels() {
let positiveLabelOptions: string[] = [];
for (const modelName of this.appState.currentModels) {
Expand Down Expand Up @@ -209,6 +258,12 @@ export class CurvesModule extends LitModule {
return data;
}

@computed private get sliceMembers(): string[] {
return [
...this.sliceService.namedSlices.values()
].flatMap((sliceData) => [...sliceData.values()]);
}

@computed
get predKeyOptions() {
return this.appState.currentModels.flatMap((modelName: string) => {
Expand Down Expand Up @@ -354,14 +409,18 @@ export class CurvesModule extends LitModule {
return html`
${this.renderPredKeySelect()}
${this.renderPositiveLabelSelect()}
<lit-checkbox
label="Show slices"
?checked=${this.showSlices}
@change=${() => {this.showSlices = !this.showSlices;}}
?disabled=${this.sliceService.sliceNames.length === 0}>
</lit-checkbox>
${this.facetingControl}
`;
// clang-format on
}

override renderImpl() {
const groups = Object.keys(this.groupedCurves);

// clang-format off
return html`
<div class='module-container'>
Expand All @@ -371,12 +430,15 @@ export class CurvesModule extends LitModule {
<div class='module-results-area ${SCROLL_SYNC_CSS_CLASS}'>
${this.datasetCurves ?
this.renderCharts(this.datasetCurves, 'Dataset', true) : null}
${groups.map(group =>
this.renderCharts(this.groupedCurves[group], group, false)
${this.showSlices ?
Object.entries(this.sliceCurves).map(
([name, data]) => this.renderCharts(data, name, false)) :
null}
${Object.entries(this.groupedCurves).map(
([name, data]) => this.renderCharts(data, name, false)
)}
</div>
</div>
`;
</div>`;
// clang-format on
}

Expand Down

0 comments on commit 2efe62b

Please sign in to comment.