Skip to content

Commit

Permalink
Adds a threshold property to MulticlassPreds type.
Browse files Browse the repository at this point in the history
Updates ClassificationService to respect MulticlassPreds field specs when resetting values.

PiperOrigin-RevId: 476073578
  • Loading branch information
RyanMullins authored and LIT team committed Sep 22, 2022
1 parent 4760b4d commit 5e91b19
Show file tree
Hide file tree
Showing 7 changed files with 245 additions and 51 deletions.
1 change: 1 addition & 0 deletions lit_nlp/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ class MulticlassPreds(_Tensor1D):
null_idx: Optional[int] = None # vocab index of negative (null) label
parent: Optional[str] = None # CategoryLabel field in input
autosort: Optional[bool] = False # Enable automatic sorting
threshold: Optional[float] = None # binary threshold, used to compute margin

@property
def num_labels(self):
Expand Down
1 change: 0 additions & 1 deletion lit_nlp/client/elements/line_chart.ts
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ export class LineChart extends ReactiveElement {
.attr("stroke", 'var(--lit-cyea-400)');

const mousemove = () => {
console.log(d3.mouse(this));
const xLocation = d3.mouse(this)[0] - this.margin;
const x0 = x.invert(xLocation);
const bisect = d3.bisect(data.map(data => data[0]), x0);
Expand Down
2 changes: 2 additions & 0 deletions lit_nlp/client/lib/lit_types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,8 @@ export class MulticlassPreds extends _Tensor1D {
parent?: string = undefined;
/** Enable automatic sorting. */
autosort?: boolean = false;
/** Binary threshold, used to compute margin. */
threshold?: number = undefined;

get num_labels() {
return this.vocab.length;
Expand Down
15 changes: 5 additions & 10 deletions lit_nlp/client/lib/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -293,20 +293,15 @@ export function handleEnterKey(e: KeyboardEvent, callback: () => void) {
* Converts the margin value to the threshold for binary classification.
*/
export function getThresholdFromMargin(margin: number) {
if (margin == null) {
return .5;
}
return margin === 0 ? .5 : 1 / (1 + Math.exp(-margin));
return !margin ? .5 : 1 / (1 + Math.exp(-margin));
}

/**
* Converts the threshold value for binary classification to the margin.
*/
export function getMarginFromThreshold(threshold: number) {
const margin = threshold !== 1 ?
(threshold !== 0 ? Math.log(threshold / (1 - threshold)) : -5) :
5;
return margin;
return threshold === 1 ? 5 :
threshold === 0 ? -5 : Math.log(threshold / (1 - threshold));
}

/**
Expand Down Expand Up @@ -485,7 +480,7 @@ export function copyToClipboard(value: string) {
* NPWS will make copy/pasting from the table behave strangely.
*/
export function chunkWords(sent: string) {
const chunkWord = (word: string) => {
function chunkWord (word: string) {
const maxLen = 15;
const chunks: string[] = [];
for (let i=0; i<word.length; i+=maxLen) {
Expand All @@ -494,7 +489,7 @@ export function chunkWords(sent: string) {
// This is not an empty string, it is a non-printing space.
const zeroWidthSpace = '​';
return chunks.join(zeroWidthSpace);
};
}
return sent.split(' ').map(word => chunkWord(word)).join(' ');
}

Expand Down
79 changes: 43 additions & 36 deletions lit_nlp/client/services/classification_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import {action, computed, observable, reaction} from 'mobx';

import {MulticlassPreds} from '../lib/lit_types';
import {FacetedData, GroupedExamples, Spec} from '../lib/types';
import {getMarginFromThreshold} from '../lib/utils';

import {LitService} from './lit_service';
import {AppState} from './state_service';
Expand Down Expand Up @@ -80,9 +81,11 @@ export class ClassificationService extends LitService {
}, {fireImmediately: true});
}

// Returns all margin settings for use as a reaction input function when
// setting up observers.
// TODO(lit-team): Remove need for this intermediate object (b/156100081)
/**
* Returns all margin settings for use as a reaction input function when
* setting up observers.
*/
// TODO(b/156100081): Remove need for this intermediate object
@computed
get allMarginSettings(): number[] {
const res: number[] = [];
Expand All @@ -106,10 +109,17 @@ export class ClassificationService extends LitService {
if (this.marginSettings[model] == null) {
this.marginSettings[model] = {};
}
this.marginSettings[model][fieldName] = {};
for (const group of Object.values(groupedExamples)) {
this.marginSettings[model][fieldName][group.displayName!] =
{facetData: group, margin: 0};
if (this.marginSettings[model][fieldName] == null) {
this.marginSettings[model][fieldName] = {};
}
const {output} = this.appState.currentModelSpecs[model].spec;
const fieldSpec = output[fieldName];
if (!(fieldSpec instanceof MulticlassPreds)) return;
const margin = fieldSpec.threshold != null ?
getMarginFromThreshold(fieldSpec.threshold) : 0;
for (const facetData of Object.values(groupedExamples)) {
this.marginSettings[model][fieldName][facetData.displayName!] =
{facetData, margin};
}
}

Expand All @@ -120,20 +130,20 @@ export class ClassificationService extends LitService {
for (const [model, output] of Object.entries(modelOutputSpecMap)) {
marginSettings[model] = {};
for (const [fieldName, fieldSpec] of Object.entries(output)) {
if (fieldSpec instanceof MulticlassPreds &&
fieldSpec.null_idx != null && fieldSpec.vocab != null) {
marginSettings[model][fieldName] = {};

if (model in this.marginSettings &&
this.marginSettings[model][fieldName] != null) {
// Reset all facets to margin = 0.
const facets = Object.keys(this.marginSettings[model][fieldName]);
for (const key of facets) {
marginSettings[model][fieldName][key] = {margin: 0};
}
}

marginSettings[model][fieldName][GLOBAL_FACET] = {margin: 0};
if (!(fieldSpec instanceof MulticlassPreds) ||
fieldSpec.null_idx == null || !fieldSpec.vocab.length) continue;

const margin = fieldSpec.threshold != null ?
getMarginFromThreshold(fieldSpec.threshold) : 0;
marginSettings[model][fieldName] = {[GLOBAL_FACET]: {margin}};

if (this.marginSettings[model] == null ||
this.marginSettings[model][fieldName] == null) continue;

const facets = Object.keys(this.marginSettings[model][fieldName]);
for (const facet of facets) {
const {facetData} = this.marginSettings[model][fieldName][facet];
marginSettings[model][fieldName][facet] = {facetData, margin};
}
}
}
Expand All @@ -142,28 +152,28 @@ export class ClassificationService extends LitService {
}

@action
setMargin(model: string, fieldName: string, value: number,
facet?: FacetedData) {
setMargin(model: string, fieldName: string, margin: number,
facetData?: FacetedData) {
if (this.marginSettings[model] == null) {
this.marginSettings[model] = {};
}
if (this.marginSettings[model][fieldName] == null) {
this.marginSettings[model][fieldName] = {};
}
if (facet == null) {
if (facetData == null) {
// If no facet provided, then update the facet for the entire dataset
// if one exists, otherwise update all facets with the provided margin.
if (GLOBAL_FACET in this.marginSettings[model][fieldName]) {
this.marginSettings[model][fieldName][GLOBAL_FACET] =
{facetData: facet, margin: value};
{facetData, margin};
} else {
for (const key of Object.keys(this.marginSettings[model][fieldName])) {
this.marginSettings[model][fieldName][key].margin = value;
this.marginSettings[model][fieldName][key].margin = margin;
}
}
} else {
this.marginSettings[model][fieldName][facet.displayName!] =
{facetData: facet, margin: value};
this.marginSettings[model][fieldName][facetData.displayName!] =
{facetData, margin};
}
}

Expand All @@ -172,16 +182,13 @@ export class ClassificationService extends LitService {
this.marginSettings[model][fieldName] == null) {
return 0;
}
const fieldMargins = this.marginSettings[model][fieldName];
if (facet == null) {
if (this.marginSettings[model][fieldName][GLOBAL_FACET] == null) {
return 0;
}
return this.marginSettings[model][fieldName][GLOBAL_FACET].margin;
return fieldMargins[GLOBAL_FACET]?.margin || 0;
} else if (facet.displayName != null) {
return fieldMargins[facet.displayName]?.margin || 0;
} else {
if (this.marginSettings[model][fieldName][facet.displayName!] == null) {
return 0;
}
return this.marginSettings[model][fieldName][facet.displayName!].margin;
return 0;
}
}
}
177 changes: 177 additions & 0 deletions lit_nlp/client/services/classification_service_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
import 'jasmine';
import {MulticlassPreds} from '../lib/lit_types';
import {getMarginFromThreshold} from '../lib/utils';
import {GroupedExamples, ModelSpec} from '../lib/types';
import {AppState} from './state_service';
import {ClassificationService} from './classification_service';

const FIELD_NAME = 'pred';
const MODEL_NAME = 'test_model';

const MULTICLASS_PRED_WITH_THRESHOLD = new MulticlassPreds();
MULTICLASS_PRED_WITH_THRESHOLD.null_idx = 0;
MULTICLASS_PRED_WITH_THRESHOLD.vocab = ['0', '1'];
MULTICLASS_PRED_WITH_THRESHOLD.threshold = 0.3;
const MULTICLASS_SPEC_WITH_THRESHOLD: ModelSpec = {
input: {},
output: {[FIELD_NAME]: MULTICLASS_PRED_WITH_THRESHOLD}
};

const MULTICLASS_PRED_WITHOUT_THRESHOLD = new MulticlassPreds();
MULTICLASS_PRED_WITHOUT_THRESHOLD.null_idx = 0;
MULTICLASS_PRED_WITHOUT_THRESHOLD.vocab = ['0', '1'];
const MULTICLASS_SPEC_WITHOUT_THRESHOLD: ModelSpec = {
input: {},
output: {[FIELD_NAME]: MULTICLASS_PRED_WITHOUT_THRESHOLD}
};

const MULTICLASS_PRED_NO_VOCAB = new MulticlassPreds();
MULTICLASS_PRED_NO_VOCAB.null_idx = 0;
const INVALID_SPEC_NO_VOCAB: ModelSpec = {
input: {},
output: {[FIELD_NAME]: MULTICLASS_PRED_NO_VOCAB}
};

const MULTICLASS_PRED_NO_NULL_IDX = new MulticlassPreds();
MULTICLASS_PRED_NO_NULL_IDX.vocab = ['0', '1'];
const INVALID_SPEC_NO_NULL_IDX: ModelSpec = {
input: {},
output: {[FIELD_NAME]: MULTICLASS_PRED_NO_NULL_IDX}
};

const INVALID_SPEC_NO_MULTICLASS_PRED: ModelSpec = {
input: {},
output: {}
};

const UPDATED_MARGIN = getMarginFromThreshold(0.8);

type MinimalAppState = Pick<AppState, 'currentModels' | 'currentModelSpecs'>;

describe('classification service test', () => {
[ // Parameterized tests for models with valid specs.
{
name: 'without a threshold',
spec: MULTICLASS_SPEC_WITHOUT_THRESHOLD,
facets: undefined,
expThreshold: undefined,
expMargin: 0
},
{
name: 'without a threshold with facets',
spec: MULTICLASS_SPEC_WITHOUT_THRESHOLD,
facets: ['TN', 'TP'],
expThreshold: undefined,
expMargin: 0
},
{
name: 'with a threshold',
spec: MULTICLASS_SPEC_WITH_THRESHOLD,
facets: undefined,
expThreshold: 0.3,
expMargin: getMarginFromThreshold(0.3)
},
{
name: 'with a threshold and facets',
spec: MULTICLASS_SPEC_WITH_THRESHOLD,
facets: ['TN', 'TP'],
expThreshold: 0.3,
expMargin: getMarginFromThreshold(0.3)
},
].forEach(({name, spec, facets, expThreshold, expMargin}) => {
const mockAppState: MinimalAppState = {
currentModels: [MODEL_NAME],
currentModelSpecs: {[MODEL_NAME]: {
spec,
datasets: [],
generators: [],
interpreters: []
}}
};

const classificationService =
new ClassificationService(mockAppState as {} as AppState);

function getMargin (facet?: string) {
const facetData = facet != null ? {displayName: facet, data: []} : undefined;
return classificationService.getMargin(MODEL_NAME, FIELD_NAME, facetData);
}

it(`derives margin settings from a spec ${name}`, () => {
const predSpec = spec.output['pred'];
expect(predSpec).toBeInstanceOf(MulticlassPreds);
expect((predSpec as MulticlassPreds).threshold).toEqual(expThreshold);
expect(getMargin()).toBe(expMargin);
});

it(`updates margin settings for specs ${name}`, () => {
classificationService.setMargin(MODEL_NAME, FIELD_NAME, UPDATED_MARGIN);
expect(getMargin()).toBe(UPDATED_MARGIN);
});

it(`resets margin settings for specs ${name}`, () => {
classificationService.resetMargins({[MODEL_NAME]: spec.output});
expect(getMargin()).toBe(expMargin);
});

if (facets != null) {
const groupedExamples = facets.reduce((obj, facet) => {
obj[facet] = {data: [], displayName: facet};
return obj;
}, {} as GroupedExamples);

classificationService.setMarginGroups(MODEL_NAME, FIELD_NAME,
groupedExamples);

for (const facet of facets) {
it(`derives margin for ${facet} facet from a spec ${name}`, () => {
expect(getMargin(facet)).toBe(expMargin);
});

it(`updates margin for ${facet} facet from a spec ${name}`, () => {
classificationService.setMargin(MODEL_NAME, FIELD_NAME, UPDATED_MARGIN,
{displayName: facet, data: []});
expect(getMargin(facet)).toBe(UPDATED_MARGIN);
});

it(`resets margin for ${facet} facet from a spec ${name}`, () => {
classificationService.resetMargins({[MODEL_NAME]: spec.output});
expect(getMargin(facet)).toBe(expMargin);
});
}
}
});

[ // Parameterized tests for models with invalid specs
{
name: 'without a multiclass pred',
spec: INVALID_SPEC_NO_MULTICLASS_PRED,
},
{
name: 'without null_idx',
spec: INVALID_SPEC_NO_NULL_IDX,
},
{
name: 'without vocab',
spec: INVALID_SPEC_NO_VOCAB,
},
].forEach(({name, spec}) => {
it(`should not compute margins ${name}`, () => {
const mockAppState: MinimalAppState = {
currentModels: [MODEL_NAME],
currentModelSpecs: {[MODEL_NAME]: {
spec,
datasets: [],
generators: [],
interpreters: []
}}
};

const {marginSettings} =
new ClassificationService(mockAppState as {} as AppState);

expect(marginSettings[MODEL_NAME]).toBeDefined();
expect(marginSettings[MODEL_NAME][FIELD_NAME]).toBeUndefined();
});
});
});
Loading

0 comments on commit 5e91b19

Please sign in to comment.