From 0656386188d6e4b6c83dab58fb4e6569ebea217e Mon Sep 17 00:00:00 2001 From: Fan Ye Date: Tue, 18 Jun 2024 09:55:11 -0700 Subject: [PATCH] Remove span graph module from LIT. PiperOrigin-RevId: 644421561 --- lit_nlp/api/layout.py | 6 - lit_nlp/client/elements/span_graph_vis.css | 74 --- lit_nlp/client/elements/span_graph_vis.ts | 472 ------------------ .../elements/span_graph_vis_vertical.css | 103 ---- .../elements/span_graph_vis_vertical.ts | 278 ----------- .../client/modules/annotated_text_module.ts | 13 +- lit_nlp/client/modules/span_graph_module.ts | 346 ------------- 7 files changed, 7 insertions(+), 1285 deletions(-) delete mode 100644 lit_nlp/client/elements/span_graph_vis.css delete mode 100644 lit_nlp/client/elements/span_graph_vis.ts delete mode 100644 lit_nlp/client/elements/span_graph_vis_vertical.css delete mode 100644 lit_nlp/client/elements/span_graph_vis_vertical.ts delete mode 100644 lit_nlp/client/modules/span_graph_module.ts diff --git a/lit_nlp/api/layout.py b/lit_nlp/api/layout.py index bb8d773e..188c8cd4 100644 --- a/lit_nlp/api/layout.py +++ b/lit_nlp/api/layout.py @@ -59,10 +59,6 @@ class LitModuleName(dtypes.EnumSerializableAsValues, enum.Enum): SimpleDatapointEditorModule = 'simple-datapoint-editor-module' # Non-replicating version of Datapoint Editor SingleDatapointEditorModule = 'single-datapoint-editor-module' - SpanGraphGoldModule = 'span-graph-gold-module' - SpanGraphGoldModuleVertical = 'span-graph-gold-module-vertical' - SpanGraphModule = 'span-graph-module' - SpanGraphModuleVertical = 'span-graph-module-vertical' TCAVModule = 'tcav-module' ThresholderModule = 'thresholder-module' TrainingDataAttributionModule = 'tda-module' @@ -126,8 +122,6 @@ def to_json(self) -> dtypes.JsonDict: modules = LitModuleName # pylint: disable=invalid-name MODEL_PREDS_MODULES = ( - modules.SpanGraphGoldModuleVertical, - modules.SpanGraphModuleVertical, modules.ClassificationModule, modules.MultilabelModule, modules.RegressionModule, diff --git a/lit_nlp/client/elements/span_graph_vis.css b/lit_nlp/client/elements/span_graph_vis.css deleted file mode 100644 index 310775f3..00000000 --- a/lit_nlp/client/elements/span_graph_vis.css +++ /dev/null @@ -1,74 +0,0 @@ -text.token-text { - alignment-baseline: middle; - dominant-baseline: central; -} - -polyline.span-bracket { - fill: none; - stroke-width: 1.2px; - stroke: var(--group-color); -} - -.selected polyline.span-bracket { - stroke-width: 1.8px; -} - -path.arc-path { - stroke-width: 1.2px; - stroke: var(--group-color); - fill: none; -} - -path.arc-path.arc-neg { - stroke-dasharray: 3,1; - stroke: gray; -} - -.selected path.arc-path { - stroke-width: 1.8px; -} - -path.arc-arrow { - stroke-width: 1.2px; - stroke: var(--group-color); - fill: var(--group-color); -} - -path.arc-arrow.arc-neg { - stroke: gray; - fill: gray; -} - - -.layer-label text { - font-family: 'Share Tech Mono', monospace; - dominant-baseline: middle; - text-anchor: end; -} - -foreignObject.span-label { - overflow: visible; -} - -.span-label div { - background-color: white; - font-family: 'Share Tech Mono', monospace; - line-height: 1.0; - padding: 1px; - padding-right: 3px; /* for occluding labels on mouseover */ - white-space: nowrap; - overflow: hidden; - text-overflow: ellipsis; - color: var(--group-color); -} - -g.selected .span-label div { - background-color: white; - overflow-x: visible; - width: fit-content; /* needed to include background when expanding */ -} - -.mousebox { - fill: white; - fill-opacity: 0.0; -} diff --git a/lit_nlp/client/elements/span_graph_vis.ts b/lit_nlp/client/elements/span_graph_vis.ts deleted file mode 100644 index 0e10cf3a..00000000 --- a/lit_nlp/client/elements/span_graph_vis.ts +++ /dev/null @@ -1,472 +0,0 @@ -/** - * @license - * Copyright 2020 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * Visualization component for structured prediction over text. - */ - -import * as d3 from 'd3'; -import {html, LitElement, svg} from 'lit'; -import {customElement, property} from 'lit/decorators.js'; -import {classMap} from 'lit/directives/class-map.js'; -import {styleMap} from 'lit/directives/style-map.js'; - -import {getVizColor} from '../lib/colors'; -import {EdgeLabel} from '../lib/dtypes'; - -import {styles} from './span_graph_vis.css'; - - -/** - * Represents a group of directed graphs anchored to token spans. - * This is the general "edge probing" representation, which can be used for many - * problems including sequence tagging, span labeling, and directed graphs like - * semantic frames, coreference, and dependency parsing. See - * https://arxiv.org/abs/1905.06316 for more on this formalism. - */ -export interface SpanGraph { - 'tokens': string[]; - 'layers': AnnotationLayer[]; -} - -/** - * A single layer of annotations, like 'pos' (part-of-speech) - * or 'ner' (named entities). - */ -export interface AnnotationLayer { - 'name': string; - 'edges': EdgeLabel[]; -} - -/* Compute points for a polyline bracket. */ -function hBracketPoints(width: number, height: number, lean: number) { - // Points for a polyline bracket. - const start = [0, 0]; - const ltop = [lean, height]; - const rtop = [width - lean, height]; - const end = [width, 0]; - return [start, ltop, rtop, end]; -} - -/** - * Compute path for a dependency arc. - */ -function arcPath( - startY: number, x1: number, x2: number, height: number, aspect: number) { - const left = Math.min(x1, x2); - const right = Math.max(x1, x2); - let pathCommands = `M ${left} ${startY} `; - if ((right - left) > (2 * aspect * height)) { - // Long arcs: draw as 90* curve, flat, then 90* curve. - const majorAxis = aspect * height; - pathCommands += `A ${majorAxis} ${height} 0 0 1 ${left + majorAxis} ${ - startY - height} `; - pathCommands += `L ${right - majorAxis} ${startY - height} `; - pathCommands += `A ${majorAxis} ${height} 0 0 1 ${right} ${startY} `; - } else { - // Short arcs: draw as single 180* curve. - height = (right - left) / (2 * aspect); - pathCommands += - `A ${(right - left) / 2} ${height} 0 0 1 ${right} ${startY} `; - } - return pathCommands; /* assign as 'd' attribute to path */ -} - -/** - * Compute path for the arrow at the end of an arc. - */ -function arcArrow(startY: number, x: number, markSize: number) { - let pathCommands = `M ${x - markSize} ${startY - (1.5 * markSize)} `; - pathCommands += `L ${x + markSize} ${startY - (1.5 * markSize)} `; - pathCommands += `L ${x} ${startY} Z`; - return pathCommands; /* assign as 'd' attribute to path */ -} - -/* Set attributes to match target's size to the source element. */ -function matchBBox(source: SVGGElement, target: SVGRectElement) { - const bbox = source.getBBox(); - target.setAttribute('x', `${bbox.x}`); - target.setAttribute('y', `${bbox.y}`); - target.setAttribute('width', `${bbox.width}`); - target.setAttribute('height', `${bbox.height}`); -} - -/** Structured prediction (SpanGraph) visualization class. */ -@customElement('span-graph-vis') -export class SpanGraphVis extends LitElement { - /* Data binding */ - @property({type: Object}) data: SpanGraph = {tokens: [], layers: []}; - @property({type: Boolean}) showLayerLabel: boolean = true; - - /* Rendering parameters */ - @property({type: Number}) lineHeight: number = 18; - @property({type: Number}) bracketHeight: number = 5.5; - @property({type: Number}) yPad: number = 5; - // For arcs between spans. - @property({type: Number}) arcBaseHeight: number = 20; - @property({type: Number}) arcMaxHeight: number = 40; - @property({type: Number}) arcAspect: number = 1.2; - @property({type: Number}) arcArrowSize: number = 4; - // Padding for SVG viewport, to avoid clipping some elements (like polyline). - @property({type: Number}) viewPad: number = 5; - // Multiplier from SVG units to screen pixels. - @property({type: Number}) svgScaling: number = 1.2; - - /* Internal rendering state */ - private tokenXBounds: Array<[number, number]> = []; - - static override get styles() { - return styles; - } - - renderTokens(tokens: string[]) { - return svg` - - - ${tokens.map(t => svg`${svg`${t + ' '}`}`)} - - `; - } - - private getTokenGroup() { - return this.shadowRoot!.querySelector('g#token-group') as SVGGElement; - } - - renderEdge(edge: EdgeLabel, color: string) { - // Positioning relative to the group transform, which will be applied later. - const labelHeight = this.lineHeight; - const labelY = -(this.bracketHeight + this.lineHeight); - - let labelText = edge.label; - let isNegativeEdge = false; - if (typeof edge.label === 'number') { - labelText = edge.label.toFixed(3); - isNegativeEdge = edge.label < 0.5; - } - const arcPathClass = - classMap({'arc-path': true, 'arc-neg': isNegativeEdge}); - const arcArrowClass = - classMap({'arc-arrow': true, 'arc-neg': isNegativeEdge}); - color = isNegativeEdge ? 'gray' : color; - // clang-format off - return svg` - - ${edge.span2 ? svg` - - - ` : ''} - - - - ${html`
${labelText}
`} -
- -
- ${edge.span2 ? svg` - - ` : ''} -
- `; - // clang-format on - } - - renderLayer(layer: AnnotationLayer, i: number) { - const rowColor = getVizColor('deep', i).color; - // Positioning relative to the group transform, which will be applied later. - const rowLabelX = -10; - const rowLabelY = -(this.bracketHeight + 0.5 * this.lineHeight); - - const orderedEdges = this.sortEdges(layer.edges); - // clang-format off - return svg` - - ${this.showLayerLabel ? svg` - - - ${svg`${layer.name}`} - - ` : null} - ${orderedEdges.map(edge => this.renderEdge(edge, rowColor))} - - `; - // clang-format on - } - - private getLayerGroup(name: string) { - return this.shadowRoot!.querySelector(`g#layer-group-${name}`) as - SVGGElement; - } - - override render() { - return svg` - - ${this.data ? this.renderTokens(this.data.tokens) : ''} - ${this.data ? this.data.layers.map(this.renderLayer.bind(this)) : ''} - `; - } - - private findTokenBounds() { - const tokenNodes = this.getTokenGroup().querySelectorAll('tspan'); - const tokenXBounds: Array<[number, number]> = []; - tokenNodes.forEach(tspan => { - // Use getBBox() to avoid a crash when tspan.getNumberOfChars() === 0. - // TODO(lit-dev): figure out why this case happens - maybe - // the nodes are not yet attached to the DOM? - const bbox = tspan.getBBox(); - tokenXBounds.push([bbox.x, bbox.x + bbox.width]); - }); - return tokenXBounds; - } - - /** - * Consistent sort order. - * Because span labels overflow to the right, we order these so the rightmost - * spans appear first in the DOM, and thus render under anything to the left - * that needs to overflow. - */ - private sortEdges(edges: EdgeLabel[]) { - return edges.slice().sort((a, b) => d3.descending(a.span1[1], b.span1[1])); - } - - /* Starting x position for a bracket, in SVG coordinates */ - private getStartX(span: [number, number]) { - return this.tokenXBounds[span[0]][0]; - } - - /* Ending x position for a bracket, in SVG coordinates */ - private getEndX(span: [number, number]) { - return this.tokenXBounds[span[1] - 1][1]; - } - - /* Find available width without clipping the next label */ - private findAvailableWidths(layerGroup: Element, edges: EdgeLabel[]): - number[] { - const availableWidths: number[] = edges.map(() => 0); - // Find available space for each label, by checking where the next label - // starts. We iterate from right to left through the spans, starting with - // the second-rightmost (i=1). - for (let i = 1; i < edges.length; i++) { - const edge = edges[i]; // this span - const nextEdge = edges[i - 1]; // right neighboring span - availableWidths[i] = - this.getStartX(nextEdge.span1) - this.getStartX(edge.span1); - } - // We don't want the rightmost label (index 0) to be cut off by the edge - // of the SVG draw area, even if the label extends past the end of the - // text. So we need to: - // 1) Set this label to fit the content, so the bounding box contains all - // the label text. - // 2) Set the available width to this rendered width, so we don't clip it - // later. - const firstSpanDiv = - // tslint:disable-next-line:no-unnecessary-type-assertion - layerGroup.querySelector('g.edge-group foreignObject div') as - HTMLDivElement | - null; - if (firstSpanDiv !== null) { - firstSpanDiv.style.width = 'fit-content'; - availableWidths[0] = firstSpanDiv.getBoundingClientRect().width; - } - return availableWidths; - } - - /* Set mouseovers, using d3. */ - private setMouseovers(group: SVGGElement, edges: EdgeLabel[]) { - const rowColor = group.dataset['color'] as string; - const grayColor = getVizColor('deep', 'other').color; - - const spanGroups = d3.select(group).selectAll('g.edge-group').data(edges); - const tokenSpans = d3.select(this.getTokenGroup()).selectAll('tspan'); - - // On mouseover, highlight this span and the corresponding text. - spanGroups.each(function(d, i) { - const colorFn = (e: unknown, j: number) => - (i === j) ? rowColor : grayColor; - const tokenColorFn = (t: unknown, j: number) => { - const inSpan1 = (d.span1[0] <= j && j < d.span1[1]); - const inSpan2 = d.span2 ? (d.span2[0] <= j && j < d.span2[1]) : false; - return (inSpan1 || inSpan2) ? rowColor : 'black'; - }; - const mouseBox = d3.select(this).select('rect.mousebox'); - mouseBox.on('mouseover', () => { - spanGroups.style('--group-color', colorFn); - tokenSpans.attr('fill', tokenColorFn); - d3.select(this).classed('selected', true); - // Ideally we'd also move this element so that it renders above - // the other groups, but SVG2 z-index is not supported by most browsers - // and simply reordering child nodes does not play well with lit-html's - // rendering logic, which relies on pointers to specific positions in - // the DOM. - // d3.select(this).classed('selected', true).raise(); - // TODO(iftenney): consider implementing a tooltip that clones this - // element but always renders above the other spans. - }); - mouseBox.on('mouseout', () => { - // Reset to original color, stored on group element. - // TODO(lit-dev): do this with another CSS class instead? - spanGroups.style('--group-color', function(e) { - return (this as SVGElement).dataset['color'] as string; - }); - tokenSpans.attr('fill', 'black'); - d3.select(this).classed('selected', false); - }); - }); - } - - /* Set y-position of rendered layers */ - private positionLayers() { - let rowStartY = this.getTokenGroup().getBBox().y - this.yPad / 2; - for (let i = 0; i < this.data.layers.length; i++) { - const group: SVGGElement = this.getLayerGroup(this.data.layers[i].name); - group.setAttribute('transform', `translate(0, ${rowStartY})`); - rowStartY -= group.getBBox().height + this.yPad; - } - } - - /* Set the SVG viewport to the bounding box of the main group. */ - private setSVGViewport() { - const mainGroup = this.shadowRoot!.querySelector('g#all') as SVGGElement; - const bbox = mainGroup.getBBox(); - const svg = this.shadowRoot!.getElementById('svg')!; - // Set bounding box to cover main group + viewPad on all sides. - const viewBox = [ - bbox.x - this.viewPad, bbox.y - this.viewPad, - bbox.width + 2 * this.viewPad, bbox.height + 2 * this.viewPad - ]; - svg.setAttribute('viewBox', `${viewBox}`); - // Set the height of the SVG as it will render on the page. - svg.setAttribute( - 'height', `${this.svgScaling * (bbox.height + 2 * this.viewPad)}`); - } - - /** - * Post-render callback. Performs imperative updates to layout and component - * sizes which need to depend on the positions of each token. Also sets up - * mouseover behavior. - */ - override updated() { - if (this.data == null) { - this.tokenXBounds = []; - return; - } - this.tokenXBounds = this.findTokenBounds(); - - // For each layer, position the span groups - for (const layer of this.data.layers) { - const orderedEdges = this.sortEdges(layer.edges); - - // Container group for this layer. - const layerGroup: SVGGElement = this.getLayerGroup(layer.name); - - // Compute available widths, needed for clipping of labels. - const availableWidths = - this.findAvailableWidths(layerGroup, orderedEdges); - - // Edge groups within this layer. - const edgeGroups = layerGroup.querySelectorAll('g.edge-group'); - edgeGroups.forEach((g, i) => { - const edge = orderedEdges[i]; - - const g1 = g.querySelector('g.at-span1')!; - // Set position within this row. - g1.setAttribute( - 'transform', `translate(${this.getStartX(edge.span1)}, 0)`); - - // Compute span width in SVG units, based on rendered token width. - const span1Width = - this.getEndX(edge.span1) - this.getStartX(edge.span1); - // Set points for span1 bracket. - const points1 = - hBracketPoints(span1Width, -1 * (this.bracketHeight - 1), 1); - g1.querySelector('polyline')!.setAttribute('points', `${points1}`); - - // Set the width for the label; this will show ellipsis for the label - // text if it is longer. - // Leave a few pixels spacing if we can afford it, but don't go - // shorter than the token width. - const displayWidth = Math.max(span1Width, availableWidths[i] - 5); - g.querySelector('foreignObject')!.setAttribute( - 'width', `${displayWidth}`); - - // If there's a second span, set up bracket - // and draw arc from span1 -> span2 with the arrow on span1. - if (edge.span2) { - const g2 = g.querySelector('g.at-span2')!; - // Set position within this row. - g2.setAttribute( - 'transform', `translate(${this.getStartX(edge.span2)}, 0)`); - // Compute span width in SVG units, based on rendered token width. - const span2Width = - this.getEndX(edge.span2) - this.getStartX(edge.span2); - const points2 = - hBracketPoints(span2Width, -1 * (this.bracketHeight - 1), 1); - g2.querySelector('polyline')!.setAttribute('points', `${points2}`); - - // Draw arc. - const startY = - -1 * (this.bracketHeight + this.lineHeight + 1 /* pad */); - const x1 = - (this.getEndX(edge.span1) + this.getStartX(edge.span1)) / 2; - let x2 = (this.getEndX(edge.span2) + this.getStartX(edge.span2)) / 2; - // Adjust arc end to avoid overlapping arrows. - // See //nlp/saft/rendering/sentence-html-renderer.js - if (x2 > x1) { - x2 -= (this.arcArrowSize + 2); - } else { - x2 += (this.arcArrowSize + 2); - } - // Adjust arc height based on edge length (# tokens between - // midpoints). See nlp_saft::SentenceRenderer::CalculateDimensions() - // from //nlp/saft/rendering/sentence-html-rendering.cc - const mid1 = (edge.span1[1] + edge.span1[0]) / 2; - const mid2 = (edge.span2[1] + edge.span2[0]) / 2; - const l = Math.min(30, Math.abs(mid2 - mid1)); - const arcHeight = Math.min( - this.arcBaseHeight + Math.round((10 - (l / 6.0)) * l), - this.arcMaxHeight); - g.querySelector('path.arc-path')!.setAttribute( - 'd', `${arcPath(startY, x1, x2, arcHeight, this.arcAspect)}`); - g.querySelector('path.arc-arrow')!.setAttribute( - 'd', `${arcArrow(startY, x1, this.arcArrowSize)}`); - } - }); - - // Set mouseover behavior for this layer. - this.setMouseovers(layerGroup, orderedEdges); - } - - // Set mouseover boxes to match the _visible_ size of the label container. - this.shadowRoot!.querySelectorAll('g.edge-group').forEach(g => { - matchBBox( - g.querySelector('foreignObject') as SVGGElement, - g.querySelector('rect.mousebox') as SVGRectElement); - }); - - // Stack layers vertically, using bounding boxes to avoid occlusion. - this.positionLayers(); - // Finally, after everything is positioned, set the viewport for the whole - // SVG. - this.setSVGViewport(); - } -} - -declare global { - interface HTMLElementTagNameMap { - 'span-graph-vis': SpanGraphVis; - } -} diff --git a/lit_nlp/client/elements/span_graph_vis_vertical.css b/lit_nlp/client/elements/span_graph_vis_vertical.css deleted file mode 100644 index 2431fed4..00000000 --- a/lit_nlp/client/elements/span_graph_vis_vertical.css +++ /dev/null @@ -1,103 +0,0 @@ -.holder { - display: flex; - font-family: 'Share Tech Mono', monospace; - position: relative; - color: #555; -} -.layer { - cursor: pointer; - color: var(--group-color); -} -.layer-label-vert { - top: calc(0px - var(--line-height)); - position: absolute; - padding-left: 7px; - display: flex; - transform: rotate(0deg); - transition: .25s transform; - transform-origin: 10px 10px; - color: var(--group-color); -} -.layer-label-vert.hidden { - transform: rotate(-90deg); -} -.column { - position: relative; - height: 100%; - transition: .25s width; -} -.column.hidden { - width: 13px !important; - opacity: 0; -} -.tokens { - z-index: 1; -} -.line { - height: var(--line-height); - box-sizing: border-box; - padding: 0 7px; - text-align: right; - white-space: nowrap; -} -.token.selected{ - color: black; -} -.child { - font-weight: bold; - filter: hue-rotate(-40deg); - border-width: 3px; -} -.parent { - font-weight: bold; - filter: hue-rotate(40deg); - border-width: 3px; -} -.selected { - font-weight: bold; - border-width: 3px; -} - -.gray { - border-color: #ddd !important; - color: #ddd -} - -.edge { - position: absolute; -} -.edge-line { - border: 1px solid var(--group-color); - border-left: 0; - width: 3px; - left: -3px; -} -.arrow-head { - border: 5px solid transparent; - border-right: 5px solid var(--group-color); - width: 0; - height: 0; - top: -5px; - left: -5px; - position: absolute; -} -.gray .arrow-head{ - border-right-color: #ddd; -} -.arrow-head.bottom{ - bottom: -5px; - top: unset; -} -.background-lines{ - position: absolute; - width: 100%; - top: -3px; -} -.background-line { - height: var(--line-height); - width: 100%; - padding: 0 7px; -} -.background-line:nth-child(odd){ - background: #f5f5f5; -} diff --git a/lit_nlp/client/elements/span_graph_vis_vertical.ts b/lit_nlp/client/elements/span_graph_vis_vertical.ts deleted file mode 100644 index 86915ef5..00000000 --- a/lit_nlp/client/elements/span_graph_vis_vertical.ts +++ /dev/null @@ -1,278 +0,0 @@ -/** - * @license - * Copyright 2020 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * Visualization component for structured prediction over text. - */ - -// tslint:disable:no-new-decorators - -import {property} from 'lit/decorators.js'; -import {customElement} from 'lit/decorators.js'; -import { html} from 'lit'; -import {classMap} from 'lit/directives/class-map.js'; -import {styleMap} from 'lit/directives/style-map.js'; -import {observable} from 'mobx'; - -import {getVizColor} from '../lib/colors'; -import {EdgeLabel} from '../lib/dtypes'; -import {ReactiveElement} from '../lib/elements'; - -import {styles} from './span_graph_vis_vertical.css'; - - -/** - * Represents a group of directed graphs anchored to token spans. - * This is the general "edge probing" representation, which can be used for many - * problems including sequence tagging, span labeling, and directed graphs like - * semantic frames, coreference, and dependency parsing. See - * https://arxiv.org/abs/1905.06316 for more on this formalism. - */ -export interface SpanGraph { - tokens: string[]; - layers: AnnotationLayer[]; -} - -/** - * A single layer of annotations, like 'pos' (part-of-speech) - * or 'ner' (named entities). - */ -export interface AnnotationLayer { - name: string; - edges: EdgeLabel[]; - hideBracket?: boolean; -} - -function formatEdgeLabel(label: string|number): string { - if (typeof (label) === 'number') { - return Number.isInteger(label) ? label.toString() : - label.toFixed(3).toString(); - } - return `${label}`; -} - -/** Structured prediction (SpanGraph) visualization class. */ -@customElement('span-graph-vis-vertical') -export class SpanGraphVis extends ReactiveElement { - /* Data binding */ - @property({ type: Object }) data: SpanGraph = { tokens: [], layers: [] }; - @property({ type: Boolean }) showLayerLabel: boolean = true; - - @observable private selectedTokIdx?: number; - @observable private readonly columnVisibility: { [key: string]: boolean } = {}; - - /* Rendering parameters */ - @property({ type: Number }) lineHeight: number = 18; - @property({ type: Number }) approxFontSize = this.lineHeight / 3; - - // Padding for SVG viewport, to avoid clipping some elements (like polyline). - @property({ type: Number }) viewPad: number = 5; - - static override get styles() { - return styles; - } - - override render() { - if (!this.data) { - return ``; - } - const host = this.shadowRoot!.host as HTMLElement; - host.style.setProperty('--line-height', `${this.lineHeight}pt`); - const tokens = this.data.tokens; - - const tokenClasses = (i: number) => classMap({ - line: true, - token: true, - selected: i === this.selectedTokIdx - }); - - // clang-format off - return html` -
-
- ${tokens.map(t => html`
`)} -
-
- ${tokens.map((t, i) => html` -
this.selectedTokIdx = i} - @mouseleave=${() => this.selectedTokIdx = undefined}> - ${t} -
- `)} -
- ${this.data.layers.map((layer, i) => this.renderLayer(layer, i))} -
`; - // clang-format on - } - - /** - * Render a given annotation layer. - */ - renderLayer(layer: AnnotationLayer, i: number) { - - if (!layer.edges.length) { - return html``; - } - - const layerStyles = styleMap({ - '--group-color': getVizColor('dark', i).color - }); - - // The column width is the width of the longest label, in pixels. - const colWidth = - Math.max( - layer.name.length, - ...layer.edges.map(e => formatEdgeLabel(e.label).length)) * - this.approxFontSize + - this.viewPad * 2; - - const colStyles = styleMap({ width: `${colWidth}pt` }); - const hidden = this.columnVisibility[layer.name]; - const columnClasses = classMap({ - 'column': true, - 'hidden': hidden - }); - - const headerClasses = classMap({ 'layer-label-vert': true, hidden }); - const onClick = () => - this.columnVisibility[layer.name] = !this.columnVisibility[layer.name]; - - // clang-format off - return html` -
- ${this.showLayerLabel ? html` -
- ${layer.name} -
` : null} -
- ${layer.edges.map(edge => this.renderEdge(edge, layer, colWidth))} -
-
- `; - // clang-format on - } - - /** - * Render an edge and its label. See the note on the SpanGraph interface - * above for more details. - */ - private renderEdge(edge: EdgeLabel, layer: AnnotationLayer, colWidth: number) { - const isArc = 'span2' in edge; - const span0 = edge.span1[0]; - const span1 = edge.span2 ? edge.span2[0] : edge.span1[1]; - const topSpan = Math.min(span0, span1); - const botSpan = Math.max(span0, span1); - - - const isInSpan = (i: number, span:[number, number]) => i >= span[0] && i < span[1]; - - // Span classes (child, parent, etc, based on the currently selected token.) - const tokSelected = this.selectedTokIdx !== undefined; - const selected = isInSpan(this.selectedTokIdx!, edge.span1) || (isArc && isInSpan(this.selectedTokIdx!, edge.span2!)); - const child = isArc && this.selectedTokIdx === span1; - const parent = isArc && this.isChildOfSelected(layer, span0); - const grayLine = tokSelected && !(selected || child); - const grayLabel = grayLine && !(parent); - - // Edge labels can be either strings or numbers; format the latter nicely. - const formattedLabel = formatEdgeLabel(edge.label); - - // Styling for the label text. - const labelWidthInPx = formattedLabel.length * this.approxFontSize; - const labelStyle = styleMap({ - top: `${span0 * this.lineHeight}pt`, - left: isArc ? `${colWidth - labelWidthInPx - this.viewPad}pt` : '', - }); - const labelClasses = classMap({ - child, parent, selected, - gray: grayLabel, - line: true, - edge: true - }); - - // Styling for the arc (a line and sometimes an arrowhead) - const arcPad = .3; - const offset = this.lineHeight / 8; - const top = isArc ? - (topSpan + arcPad) * this.lineHeight + (topSpan === span0 ? 0 : this.viewPad) : - topSpan * (this.lineHeight) - offset; - const bottom = isArc ? - (botSpan + arcPad) * this.lineHeight + (botSpan === span0 ? 0 : -this.viewPad) : - botSpan * (this.lineHeight) - 2 * offset; - - const arcHeight = bottom - top; - const width = isArc ? `${Math.max(arcHeight / 2, this.lineHeight / 2)}pt` : ''; - - const rad = isArc ? arcHeight / 2 : 3; - const lineStyle = styleMap({ - top: `${top}pt`, - height: `${arcHeight}pt`, - width, - 'border-radius': `0pt ${rad}pt ${rad}pt 0pt`, - left: isArc ? `${colWidth + 10}pt` : '', - visibility: layer.hideBracket ? 'hidden' : 'visble', - }); - - const arrowHeadClasses = classMap({ - 'arrow-head': true, - 'bottom': topSpan === span1, - }); - - const arrowClasses = classMap({ - child, - parent: selected, - gray: grayLine, - edge: true, - 'edge-line': true - }); - - return html` -
- ${isArc ? html`
` : ''} -
-
- ${formattedLabel} -
- `; - } - - /** - * Is this token (indicated by tokenIdx) a child of the selected token at - * the specified layer. This assumes that the edge goes from span1 to span2, - * as in a dependency parse tree. - */ - isChildOfSelected(layer: AnnotationLayer, tokenIdx: number) { - for (let j = 0; j < layer.edges.length; j++) { - const edge = layer.edges[j]; - if (edge.span2 && - (this.selectedTokIdx === edge.span1[0]) && - (tokenIdx === edge.span2[0])) { - return true; - } - } - return false; - } - -} - -declare global { - interface HTMLElementTagNameMap { - 'span-graph-vis-vertical': SpanGraphVis; - } -} diff --git a/lit_nlp/client/modules/annotated_text_module.ts b/lit_nlp/client/modules/annotated_text_module.ts index 286d8c76..7165cb12 100644 --- a/lit_nlp/client/modules/annotated_text_module.ts +++ b/lit_nlp/client/modules/annotated_text_module.ts @@ -6,7 +6,7 @@ * spans in running text, which is well-suited for tasks like QA or entity * recognition which have a small number of spans over a longer passage. * - * Similar to span_graph_module, we provide two module classes: + * We provide two module classes: * - AnnotatedTextGoldModule for gold annotations (in the input data) * - AnnotatedTextModule for model predictions */ @@ -14,18 +14,17 @@ // tslint:disable:no-new-decorators import '../elements/annotated_text_vis'; +import {html} from 'lit'; import {customElement} from 'lit/decorators.js'; -import { html} from 'lit'; import {observable} from 'mobx'; import {LitModule} from '../core/lit_module'; import {type AnnotationGroups, TextSegments} from '../elements/annotated_text_vis'; import {MultiSegmentAnnotations, TextSegment} from '../lib/lit_types'; +import {styles as sharedStyles} from '../lib/shared_styles.css'; import {type IndexedInput, ModelInfoMap, Spec} from '../lib/types'; import {doesOutputSpecContain, filterToKeys, findSpecKeys} from '../lib/utils'; -import {styles as sharedStyles} from '../lib/shared_styles.css'; - /** LIT module for model output. */ @customElement('annotated-text-gold-module') export class AnnotatedTextGoldModule extends LitModule { @@ -80,7 +79,8 @@ export class AnnotatedTextGoldModule extends LitModule { // clang-format on } - static override shouldDisplayModule(modelSpecs: ModelInfoMap, datasetSpec: Spec) { + static override shouldDisplayModule( + modelSpecs: ModelInfoMap, datasetSpec: Spec) { return findSpecKeys(datasetSpec, MultiSegmentAnnotations).length > 0; } } @@ -159,7 +159,8 @@ export class AnnotatedTextModule extends LitModule { // clang-format on } - static override shouldDisplayModule(modelSpecs: ModelInfoMap, datasetSpec: Spec) { + static override shouldDisplayModule( + modelSpecs: ModelInfoMap, datasetSpec: Spec) { return doesOutputSpecContain(modelSpecs, MultiSegmentAnnotations); } } diff --git a/lit_nlp/client/modules/span_graph_module.ts b/lit_nlp/client/modules/span_graph_module.ts deleted file mode 100644 index 52bc33a5..00000000 --- a/lit_nlp/client/modules/span_graph_module.ts +++ /dev/null @@ -1,346 +0,0 @@ -/** - * @license - * Copyright 2020 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * Module within LIT for showing sequence and span tagging - * results. - */ - -// tslint:disable:no-new-decorators -import '../elements/span_graph_vis'; -import '../elements/span_graph_vis_vertical'; - -import {customElement} from 'lit/decorators.js'; -import {css, html} from 'lit'; -import {computed, observable} from 'mobx'; - -import {LitModule} from '../core/lit_module'; -import {AnnotationLayer, SpanGraph} from '../elements/span_graph_vis_vertical'; -import {EdgeLabel, SpanLabel} from '../lib/dtypes'; -import {EdgeLabels, SequenceTags, SpanLabels, LitTypeTypesList, LitTypeWithAlign, TextSegment, Tokens} from '../lib/lit_types'; -import {IndexedInput, Input, ModelInfoMap, Preds, Spec} from '../lib/types'; -import {findSpecKeys} from '../lib/utils'; - -import {styles as sharedStyles} from '../lib/shared_styles.css'; - -interface FieldNameMultimap { - [fieldName: string]: string[]; -} - -interface Annotations { - [tokenKey: string]: SpanGraph; -} - -// Shared by gold and preds modules. -const moduleStyles = css` - .outer-container { - display: flex; - flex-direction: column; - justify-content: center; - position: relative; - overflow: hidden; - } - - .token-group { - padding-top: 40px; - } - - .field-title { - padding: 4px; - } -`; - -const supportedPredTypes: LitTypeTypesList = - [SequenceTags, SpanLabels, EdgeLabels]; - -/** - * Convert sequence tags to a list of length-1 span labels. - */ -function tagsToEdges(tags: string[]): EdgeLabel[] { - return tags.map((label: string, i: number) => { - return {span1: [i, i + 1], label} as EdgeLabel; - }); -} - -/** - * Convert span labels to single-sided edge labels. - */ -function spansToEdges(spans: SpanLabel[]): EdgeLabel[] { - return spans.map( - d => ({span1: [d.start, d.end], label: d.label as string} as EdgeLabel)); -} - -function mapTokenToTags(spec: Spec): FieldNameMultimap { - const tagKeys = findSpecKeys(spec, supportedPredTypes); - const tokenKeys = findSpecKeys(spec, Tokens); - - // Make a mapping of token keys to one or more tag sets - const tokenToTags: FieldNameMultimap = {}; - for (const tagKey of tagKeys) { - const {align: tokenKey} = spec[tagKey] as LitTypeWithAlign; - if (tokenKey == null || !tokenKeys.includes(tokenKey)) { - continue; - } else if (tokenToTags[tokenKey] == null) { - tokenToTags[tokenKey] = []; - } - tokenToTags[tokenKey].push(tagKey); - } - return tokenToTags; -} - -function parseInput(data: Input|Preds, spec: Spec): Annotations { - const tokenToTags = mapTokenToTags(spec); - - // Render a row for each set of tokens - const ret: Annotations = {}; - for (const tokenKey of Object.keys(tokenToTags)) { - const annotationLayers: AnnotationLayer[] = []; - for (const tagKey of tokenToTags[tokenKey]) { - let edges = data[tagKey]; - let hideBracket = false; - // Temporary workaround: if we manually create a new datapoint, the span - // or tag field may be "" rather than []. - // TODO(lit-team): remove this once the datapoint editor is type-safe - // for structured fields. - if (edges.length === 0) { - edges = []; - } - if (spec[tagKey] instanceof SequenceTags) { - edges = tagsToEdges(edges); - hideBracket = true; - } else if (spec[tagKey] instanceof SpanLabels) { - edges = spansToEdges(edges); - } - annotationLayers.push({name: tagKey, edges, hideBracket}); - } - // Try to infer tokens from text, if that field is empty. - let tokens = data[tokenKey]; - if (tokens.length === 0) { - const textKey = findSpecKeys(spec, TextSegment)[0]; - tokens = data[textKey].split(); - } - ret[tokenKey] = {tokens, layers: annotationLayers}; - } - return ret; -} - -function renderTokenGroups( - data: Annotations, spec: Spec, orientation: 'horizontal'|'vertical') { - const tokenToTags = mapTokenToTags(spec); - const visElement = (data: SpanGraph, showLayerLabel: boolean) => { - if (orientation === 'vertical') { - return html``; - } else { - return html``; - } - }; - // clang-format off - return html`${Object.keys(tokenToTags).map(tokenKey => { - const labelHere = data[tokenKey]?.layers?.length === 1; - return html` -
- ${labelHere ? - html`
${data[tokenKey].layers[0].name}
` - : null} - ${visElement(data[tokenKey], !labelHere)} -
- `; - })}`; - // clang-format on -} - -/** Gold predictions module class. */ -@customElement('span-graph-gold-module') -export class SpanGraphGoldModule extends LitModule { - static override title = 'Structured Prediction (gold)'; - static override duplicateForExampleComparison = true; - static override duplicateForModelComparison = false; - static override duplicateAsRow = false; - static override numCols = 4; - static override template = - (model: string, selectionServiceIndex: number, shouldReact: number) => html` - - `; - static orientation = 'horizontal'; - - @computed - get dataSpec() { - return this.appState.currentDatasetSpec; - } - - @computed - get goldDisplayData(): Annotations { - const input = this.selectionService.primarySelectedInputData; - if (input === null) { - return {}; - } else { - return parseInput(input.data, this.dataSpec); - } - } - - static override get styles() { - return [sharedStyles, moduleStyles]; - } - - // tslint:disable:no-any - override renderImpl() { - // If more than one model is selected, SpanGraphModule will be offset - // vertically due to the model name header, while this one won't be. - // So, add an offset so that the content still aligns when there is a - // SpanGraphGoldModule and a SpanGraphModule side-by-side. - const offsetForHeader = !this.appState.compareExamplesEnabled && - this.appState.currentModels.length > 1; - // clang-format off - return html` - ${offsetForHeader? html`
` : null} -
- ${ - renderTokenGroups( - this.goldDisplayData, this.dataSpec, - (this.constructor as any).orientation)} -
- `; - // clang-format on - } - // tslint:enable:no-any - - static override shouldDisplayModule(modelSpecs: ModelInfoMap, datasetSpec: Spec) { - const hasTokens = findSpecKeys(datasetSpec, Tokens).length > 0; - const hasSupportedPreds = - findSpecKeys(datasetSpec, supportedPredTypes).length > 0; - return (hasTokens && hasSupportedPreds); - } -} - -/** Model output module class. */ -@customElement('span-graph-module') -export class SpanGraphModule extends LitModule { - static override title = 'Structured Prediction (model preds)'; - static override duplicateForExampleComparison = true; - static override duplicateAsRow = false; - static override numCols = 4; - static override template = - (model: string, selectionServiceIndex: number, shouldReact: number) => html` - - `; - static orientation = 'horizontal'; - - @computed - get predSpec() { - return this.appState.getModelSpec(this.model).output; - } - - // This is updated with an API call, via a reaction. - @observable predDisplayData: Annotations = {}; - - private async updatePredDisplayData(input: IndexedInput|null) { - if (input === null) { - this.predDisplayData = {}; - } else { - const promise = this.apiService.getPreds( - [input], this.model, this.appState.currentDataset, - [Tokens, ...supportedPredTypes]); - - const results = await this.loadLatest('getPreds', promise); - if (!results) return; - - this.predDisplayData = parseInput(results[0], this.predSpec); - } - } - - static override get styles() { - return [sharedStyles, moduleStyles]; - } - - override firstUpdated() { - this.reactImmediately( - () => this.selectionService.primarySelectedInputData, input => { - this.updatePredDisplayData(input); - }); - } - - // tslint:disable:no-any - override renderImpl() { - return html` -
- ${ - renderTokenGroups( - this.predDisplayData, this.predSpec, - (this.constructor as any).orientation)} -
- `; - } - // tslint:enable:no-any - - static override shouldDisplayModule(modelSpecs: ModelInfoMap, datasetSpec: Spec) { - const models = Object.keys(modelSpecs); - for (let modelNum = 0; modelNum < models.length; modelNum++) { - const spec = modelSpecs[models[modelNum]].spec; - const hasTokens = findSpecKeys(spec.output, Tokens).length > 0; - const hasSupportedPreds = - findSpecKeys(spec.output, supportedPredTypes).length > 0; - if (hasTokens && hasSupportedPreds) { - return true; - } - } - return false; - } -} - -// tslint:disable:class-as-namespace - -/** Gold predictions module class. */ -@customElement('span-graph-gold-module-vertical') -export class SpanGraphGoldModuleVertical extends SpanGraphGoldModule { - static override duplicateAsRow = true; - static override orientation = 'vertical'; - static override numCols = 4; - static override template = - (model: string, selectionServiceIndex: number, shouldReact: number) => html` - - `; -} - -/** Model output module class. */ -@customElement('span-graph-module-vertical') -export class SpanGraphModuleVertical extends SpanGraphModule { - static override duplicateAsRow = true; - static override orientation = 'vertical'; - static override template = - (model: string, selectionServiceIndex: number, shouldReact: number) => html` - - `; -} - -// tslint:enable:class-as-namespace - -declare global { - interface HTMLElementTagNameMap { - 'span-graph-gold-module': SpanGraphGoldModule; - 'span-graph-module': SpanGraphModule; - // TODO(b/172979677): make these parameterized versions, rather than - // separate classes. - 'span-graph-gold-module-vertical': SpanGraphGoldModuleVertical; - 'span-graph-module-vertical': SpanGraphModuleVertical; - } -}