From 5e91b1984700f6c1bb25b05d25e091d8d522c7e9 Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Thu, 22 Sep 2022 05:52:41 -0700 Subject: [PATCH] Adds a threshold property to MulticlassPreds type. Updates ClassificationService to respect MulticlassPreds field specs when resetting values. PiperOrigin-RevId: 476073578 --- lit_nlp/api/types.py | 1 + lit_nlp/client/elements/line_chart.ts | 1 - lit_nlp/client/lib/lit_types.ts | 2 + lit_nlp/client/lib/utils.ts | 15 +- .../client/services/classification_service.ts | 79 ++++---- .../services/classification_service_test.ts | 177 ++++++++++++++++++ lit_nlp/examples/models/glue_models.py | 21 ++- 7 files changed, 245 insertions(+), 51 deletions(-) create mode 100644 lit_nlp/client/services/classification_service_test.ts diff --git a/lit_nlp/api/types.py b/lit_nlp/api/types.py index b38b01d4..c5b11d24 100644 --- a/lit_nlp/api/types.py +++ b/lit_nlp/api/types.py @@ -278,6 +278,7 @@ class MulticlassPreds(_Tensor1D): null_idx: Optional[int] = None # vocab index of negative (null) label parent: Optional[str] = None # CategoryLabel field in input autosort: Optional[bool] = False # Enable automatic sorting + threshold: Optional[float] = None # binary threshold, used to compute margin @property def num_labels(self): diff --git a/lit_nlp/client/elements/line_chart.ts b/lit_nlp/client/elements/line_chart.ts index 48d59bf6..69dd6af9 100644 --- a/lit_nlp/client/elements/line_chart.ts +++ b/lit_nlp/client/elements/line_chart.ts @@ -130,7 +130,6 @@ export class LineChart extends ReactiveElement { .attr("stroke", 'var(--lit-cyea-400)'); const mousemove = () => { - console.log(d3.mouse(this)); const xLocation = d3.mouse(this)[0] - this.margin; const x0 = x.invert(xLocation); const bisect = d3.bisect(data.map(data => data[0]), x0); diff --git a/lit_nlp/client/lib/lit_types.ts b/lit_nlp/client/lib/lit_types.ts index 04dd4cab..16654108 100644 --- a/lit_nlp/client/lib/lit_types.ts +++ b/lit_nlp/client/lib/lit_types.ts @@ -264,6 +264,8 @@ export class MulticlassPreds extends _Tensor1D { parent?: string = undefined; /** Enable automatic sorting. */ autosort?: boolean = false; + /** Binary threshold, used to compute margin. */ + threshold?: number = undefined; get num_labels() { return this.vocab.length; diff --git a/lit_nlp/client/lib/utils.ts b/lit_nlp/client/lib/utils.ts index addfd885..0719fbdc 100644 --- a/lit_nlp/client/lib/utils.ts +++ b/lit_nlp/client/lib/utils.ts @@ -293,20 +293,15 @@ export function handleEnterKey(e: KeyboardEvent, callback: () => void) { * Converts the margin value to the threshold for binary classification. */ export function getThresholdFromMargin(margin: number) { - if (margin == null) { - return .5; - } - return margin === 0 ? .5 : 1 / (1 + Math.exp(-margin)); + return !margin ? .5 : 1 / (1 + Math.exp(-margin)); } /** * Converts the threshold value for binary classification to the margin. */ export function getMarginFromThreshold(threshold: number) { - const margin = threshold !== 1 ? - (threshold !== 0 ? Math.log(threshold / (1 - threshold)) : -5) : - 5; - return margin; + return threshold === 1 ? 5 : + threshold === 0 ? -5 : Math.log(threshold / (1 - threshold)); } /** @@ -485,7 +480,7 @@ export function copyToClipboard(value: string) { * NPWS will make copy/pasting from the table behave strangely. */ export function chunkWords(sent: string) { - const chunkWord = (word: string) => { + function chunkWord (word: string) { const maxLen = 15; const chunks: string[] = []; for (let i=0; i chunkWord(word)).join(' '); } diff --git a/lit_nlp/client/services/classification_service.ts b/lit_nlp/client/services/classification_service.ts index 513fa344..c74da0b0 100644 --- a/lit_nlp/client/services/classification_service.ts +++ b/lit_nlp/client/services/classification_service.ts @@ -20,6 +20,7 @@ import {action, computed, observable, reaction} from 'mobx'; import {MulticlassPreds} from '../lib/lit_types'; import {FacetedData, GroupedExamples, Spec} from '../lib/types'; +import {getMarginFromThreshold} from '../lib/utils'; import {LitService} from './lit_service'; import {AppState} from './state_service'; @@ -80,9 +81,11 @@ export class ClassificationService extends LitService { }, {fireImmediately: true}); } - // Returns all margin settings for use as a reaction input function when - // setting up observers. - // TODO(lit-team): Remove need for this intermediate object (b/156100081) + /** + * Returns all margin settings for use as a reaction input function when + * setting up observers. + */ + // TODO(b/156100081): Remove need for this intermediate object @computed get allMarginSettings(): number[] { const res: number[] = []; @@ -106,10 +109,17 @@ export class ClassificationService extends LitService { if (this.marginSettings[model] == null) { this.marginSettings[model] = {}; } - this.marginSettings[model][fieldName] = {}; - for (const group of Object.values(groupedExamples)) { - this.marginSettings[model][fieldName][group.displayName!] = - {facetData: group, margin: 0}; + if (this.marginSettings[model][fieldName] == null) { + this.marginSettings[model][fieldName] = {}; + } + const {output} = this.appState.currentModelSpecs[model].spec; + const fieldSpec = output[fieldName]; + if (!(fieldSpec instanceof MulticlassPreds)) return; + const margin = fieldSpec.threshold != null ? + getMarginFromThreshold(fieldSpec.threshold) : 0; + for (const facetData of Object.values(groupedExamples)) { + this.marginSettings[model][fieldName][facetData.displayName!] = + {facetData, margin}; } } @@ -120,20 +130,20 @@ export class ClassificationService extends LitService { for (const [model, output] of Object.entries(modelOutputSpecMap)) { marginSettings[model] = {}; for (const [fieldName, fieldSpec] of Object.entries(output)) { - if (fieldSpec instanceof MulticlassPreds && - fieldSpec.null_idx != null && fieldSpec.vocab != null) { - marginSettings[model][fieldName] = {}; - - if (model in this.marginSettings && - this.marginSettings[model][fieldName] != null) { - // Reset all facets to margin = 0. - const facets = Object.keys(this.marginSettings[model][fieldName]); - for (const key of facets) { - marginSettings[model][fieldName][key] = {margin: 0}; - } - } - - marginSettings[model][fieldName][GLOBAL_FACET] = {margin: 0}; + if (!(fieldSpec instanceof MulticlassPreds) || + fieldSpec.null_idx == null || !fieldSpec.vocab.length) continue; + + const margin = fieldSpec.threshold != null ? + getMarginFromThreshold(fieldSpec.threshold) : 0; + marginSettings[model][fieldName] = {[GLOBAL_FACET]: {margin}}; + + if (this.marginSettings[model] == null || + this.marginSettings[model][fieldName] == null) continue; + + const facets = Object.keys(this.marginSettings[model][fieldName]); + for (const facet of facets) { + const {facetData} = this.marginSettings[model][fieldName][facet]; + marginSettings[model][fieldName][facet] = {facetData, margin}; } } } @@ -142,28 +152,28 @@ export class ClassificationService extends LitService { } @action - setMargin(model: string, fieldName: string, value: number, - facet?: FacetedData) { + setMargin(model: string, fieldName: string, margin: number, + facetData?: FacetedData) { if (this.marginSettings[model] == null) { this.marginSettings[model] = {}; } if (this.marginSettings[model][fieldName] == null) { this.marginSettings[model][fieldName] = {}; } - if (facet == null) { + if (facetData == null) { // If no facet provided, then update the facet for the entire dataset // if one exists, otherwise update all facets with the provided margin. if (GLOBAL_FACET in this.marginSettings[model][fieldName]) { this.marginSettings[model][fieldName][GLOBAL_FACET] = - {facetData: facet, margin: value}; + {facetData, margin}; } else { for (const key of Object.keys(this.marginSettings[model][fieldName])) { - this.marginSettings[model][fieldName][key].margin = value; + this.marginSettings[model][fieldName][key].margin = margin; } } } else { - this.marginSettings[model][fieldName][facet.displayName!] = - {facetData: facet, margin: value}; + this.marginSettings[model][fieldName][facetData.displayName!] = + {facetData, margin}; } } @@ -172,16 +182,13 @@ export class ClassificationService extends LitService { this.marginSettings[model][fieldName] == null) { return 0; } + const fieldMargins = this.marginSettings[model][fieldName]; if (facet == null) { - if (this.marginSettings[model][fieldName][GLOBAL_FACET] == null) { - return 0; - } - return this.marginSettings[model][fieldName][GLOBAL_FACET].margin; + return fieldMargins[GLOBAL_FACET]?.margin || 0; + } else if (facet.displayName != null) { + return fieldMargins[facet.displayName]?.margin || 0; } else { - if (this.marginSettings[model][fieldName][facet.displayName!] == null) { - return 0; - } - return this.marginSettings[model][fieldName][facet.displayName!].margin; + return 0; } } } diff --git a/lit_nlp/client/services/classification_service_test.ts b/lit_nlp/client/services/classification_service_test.ts new file mode 100644 index 00000000..d81affa1 --- /dev/null +++ b/lit_nlp/client/services/classification_service_test.ts @@ -0,0 +1,177 @@ +import 'jasmine'; +import {MulticlassPreds} from '../lib/lit_types'; +import {getMarginFromThreshold} from '../lib/utils'; +import {GroupedExamples, ModelSpec} from '../lib/types'; +import {AppState} from './state_service'; +import {ClassificationService} from './classification_service'; + +const FIELD_NAME = 'pred'; +const MODEL_NAME = 'test_model'; + +const MULTICLASS_PRED_WITH_THRESHOLD = new MulticlassPreds(); +MULTICLASS_PRED_WITH_THRESHOLD.null_idx = 0; +MULTICLASS_PRED_WITH_THRESHOLD.vocab = ['0', '1']; +MULTICLASS_PRED_WITH_THRESHOLD.threshold = 0.3; +const MULTICLASS_SPEC_WITH_THRESHOLD: ModelSpec = { + input: {}, + output: {[FIELD_NAME]: MULTICLASS_PRED_WITH_THRESHOLD} +}; + +const MULTICLASS_PRED_WITHOUT_THRESHOLD = new MulticlassPreds(); +MULTICLASS_PRED_WITHOUT_THRESHOLD.null_idx = 0; +MULTICLASS_PRED_WITHOUT_THRESHOLD.vocab = ['0', '1']; +const MULTICLASS_SPEC_WITHOUT_THRESHOLD: ModelSpec = { + input: {}, + output: {[FIELD_NAME]: MULTICLASS_PRED_WITHOUT_THRESHOLD} +}; + +const MULTICLASS_PRED_NO_VOCAB = new MulticlassPreds(); +MULTICLASS_PRED_NO_VOCAB.null_idx = 0; +const INVALID_SPEC_NO_VOCAB: ModelSpec = { + input: {}, + output: {[FIELD_NAME]: MULTICLASS_PRED_NO_VOCAB} +}; + +const MULTICLASS_PRED_NO_NULL_IDX = new MulticlassPreds(); +MULTICLASS_PRED_NO_NULL_IDX.vocab = ['0', '1']; +const INVALID_SPEC_NO_NULL_IDX: ModelSpec = { + input: {}, + output: {[FIELD_NAME]: MULTICLASS_PRED_NO_NULL_IDX} +}; + +const INVALID_SPEC_NO_MULTICLASS_PRED: ModelSpec = { + input: {}, + output: {} +}; + +const UPDATED_MARGIN = getMarginFromThreshold(0.8); + +type MinimalAppState = Pick; + +describe('classification service test', () => { + [ // Parameterized tests for models with valid specs. + { + name: 'without a threshold', + spec: MULTICLASS_SPEC_WITHOUT_THRESHOLD, + facets: undefined, + expThreshold: undefined, + expMargin: 0 + }, + { + name: 'without a threshold with facets', + spec: MULTICLASS_SPEC_WITHOUT_THRESHOLD, + facets: ['TN', 'TP'], + expThreshold: undefined, + expMargin: 0 + }, + { + name: 'with a threshold', + spec: MULTICLASS_SPEC_WITH_THRESHOLD, + facets: undefined, + expThreshold: 0.3, + expMargin: getMarginFromThreshold(0.3) + }, + { + name: 'with a threshold and facets', + spec: MULTICLASS_SPEC_WITH_THRESHOLD, + facets: ['TN', 'TP'], + expThreshold: 0.3, + expMargin: getMarginFromThreshold(0.3) + }, + ].forEach(({name, spec, facets, expThreshold, expMargin}) => { + const mockAppState: MinimalAppState = { + currentModels: [MODEL_NAME], + currentModelSpecs: {[MODEL_NAME]: { + spec, + datasets: [], + generators: [], + interpreters: [] + }} + }; + + const classificationService = + new ClassificationService(mockAppState as {} as AppState); + + function getMargin (facet?: string) { + const facetData = facet != null ? {displayName: facet, data: []} : undefined; + return classificationService.getMargin(MODEL_NAME, FIELD_NAME, facetData); + } + + it(`derives margin settings from a spec ${name}`, () => { + const predSpec = spec.output['pred']; + expect(predSpec).toBeInstanceOf(MulticlassPreds); + expect((predSpec as MulticlassPreds).threshold).toEqual(expThreshold); + expect(getMargin()).toBe(expMargin); + }); + + it(`updates margin settings for specs ${name}`, () => { + classificationService.setMargin(MODEL_NAME, FIELD_NAME, UPDATED_MARGIN); + expect(getMargin()).toBe(UPDATED_MARGIN); + }); + + it(`resets margin settings for specs ${name}`, () => { + classificationService.resetMargins({[MODEL_NAME]: spec.output}); + expect(getMargin()).toBe(expMargin); + }); + + if (facets != null) { + const groupedExamples = facets.reduce((obj, facet) => { + obj[facet] = {data: [], displayName: facet}; + return obj; + }, {} as GroupedExamples); + + classificationService.setMarginGroups(MODEL_NAME, FIELD_NAME, + groupedExamples); + + for (const facet of facets) { + it(`derives margin for ${facet} facet from a spec ${name}`, () => { + expect(getMargin(facet)).toBe(expMargin); + }); + + it(`updates margin for ${facet} facet from a spec ${name}`, () => { + classificationService.setMargin(MODEL_NAME, FIELD_NAME, UPDATED_MARGIN, + {displayName: facet, data: []}); + expect(getMargin(facet)).toBe(UPDATED_MARGIN); + }); + + it(`resets margin for ${facet} facet from a spec ${name}`, () => { + classificationService.resetMargins({[MODEL_NAME]: spec.output}); + expect(getMargin(facet)).toBe(expMargin); + }); + } + } + }); + + [ // Parameterized tests for models with invalid specs + { + name: 'without a multiclass pred', + spec: INVALID_SPEC_NO_MULTICLASS_PRED, + }, + { + name: 'without null_idx', + spec: INVALID_SPEC_NO_NULL_IDX, + }, + { + name: 'without vocab', + spec: INVALID_SPEC_NO_VOCAB, + }, + ].forEach(({name, spec}) => { + it(`should not compute margins ${name}`, () => { + const mockAppState: MinimalAppState = { + currentModels: [MODEL_NAME], + currentModelSpecs: {[MODEL_NAME]: { + spec, + datasets: [], + generators: [], + interpreters: [] + }} + }; + + const {marginSettings} = + new ClassificationService(mockAppState as {} as AppState); + + expect(marginSettings[MODEL_NAME]).toBeDefined(); + expect(marginSettings[MODEL_NAME][FIELD_NAME]).toBeUndefined(); + }); + }); +}); diff --git a/lit_nlp/examples/models/glue_models.py b/lit_nlp/examples/models/glue_models.py index de8e58dc..1466e241 100644 --- a/lit_nlp/examples/models/glue_models.py +++ b/lit_nlp/examples/models/glue_models.py @@ -16,6 +16,7 @@ JsonDict = lit_types.JsonDict Spec = lit_types.Spec +TFSequenceClassifierOutput = transformers.modeling_tf_outputs.TFSequenceClassifierOutput @attr.s(auto_attribs=True, kw_only=True) @@ -350,10 +351,13 @@ def predict_minibatch(self, inputs: Iterable[JsonDict]): model_inputs = encoded_input.copy() model_inputs["input_ids"] = None - out: transformers.modeling_tf_outputs.TFSequenceClassifierOutput = \ - self.model(model_inputs, inputs_embeds=input_embs, training=False, - output_hidden_states=True, output_attentions=True, - return_dict=True) + out: TFSequenceClassifierOutput = self.model( + model_inputs, + inputs_embeds=input_embs, + training=False, + output_hidden_states=True, + output_attentions=True, + return_dict=True) batched_outputs = { "input_ids": encoded_input["input_ids"], @@ -540,3 +544,12 @@ def __init__(self, *args, **kw): labels=["non-toxic", "toxic"], null_label_idx=0, **kw) + + def output_spec(self) -> Spec: + ret = super().output_spec() + ret["probas"] = lit_types.MulticlassPreds( + parent=self.config.label_name, + vocab=self.config.labels, + null_idx=self.config.null_label_idx, + threshold=0.3) + return ret